{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy \n",
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "path_to_constants_folder ='PATH_TO_CONSTANTS_FOLDER'\n",
    "sys.path.append(path_to_constants_folder)\n",
    "# Your local path in here\n",
    "import constants as constants\n",
    "import seaborn as sns\n",
    "from matplotlib.backends.backend_pdf import PdfPages as PDF\n",
    "import datetime\n",
    "import warnings\n",
    "#warnings.filterwarnings('ignore')\n",
    "import logging\n",
    "import os\n",
    "import sys \n",
    "current_date = datetime.datetime.now().strftime('%Y-%m-%d')\n",
    "from collections import Counter\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "names_latex = ['Total Events',\n",
    "     'Early Warning',\n",
    "     'Sync Warning',\n",
    "     'Late Warning',\n",
    "     'Soft Warning',\n",
    "     'Missed Outbreaks',\n",
    "     'Warning, activity increases',\n",
    "     'False Alarm']\n",
    "\n",
    "np.random.seed(2)\n",
    "\n",
    "\n",
    "## Path to dataset\n",
    "PATH_DATASET_EWS = 'PATH_TO_DATASET/{0}/' # leave /{0}/ at the end as it is used based on the geo option (state or county)\n",
    "# Where all experiment folders will be generated\n",
    "path_output_results = '/tmp'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_results_stacked(experiment_names):\n",
    "    '''\n",
    "        Runs the event analysis of every experiment and places the resulting\n",
    "        count columns side by side.\n",
    "\n",
    "        INPUT\n",
    "        _____\n",
    "\n",
    "        experiment_names : iterable of str\n",
    "            Names containing 'lucas' are analysed with perform_analysis_lucas,\n",
    "            every other name with perform_analysis.\n",
    "\n",
    "        OUTPUT\n",
    "        ______\n",
    "\n",
    "        Pandas DataFrame with one column per experiment; rows are selected and\n",
    "        ordered through the module-level rows_ordered.\n",
    "    '''\n",
    "    stacked_columns = []\n",
    "    for experiment_name in experiment_names:\n",
    "        analysis_fn = perform_analysis_lucas if 'lucas' in experiment_name else perform_analysis\n",
    "        summary, _ = analysis_fn(experiment_name)\n",
    "        reordered = summary.T[rows_ordered].T\n",
    "        stacked_columns.append(reordered[[experiment_name]].copy())\n",
    "\n",
    "    return pd.concat(stacked_columns, axis=1)\n",
    "\n",
    "def format_column(dfs_corrected, column_name, n_events):\n",
    "\n",
    "    dfs_corrected.loc['Total Events',column_name] ='{0} (100%)'.format(n_events)\n",
    "    m = dfs_corrected.loc['Early Warning',column_name]\n",
    "    dfs_corrected.loc['Early Warning',column_name] ='{0} ({1}%)'.format(dfs_corrected.loc['Early Warning',column_name], int(np.round((dfs_corrected.loc['Early Warning',column_name]/n_events)*100))  )\n",
    "    m += dfs_corrected.loc['Sync Warning',column_name]\n",
    "    dfs_corrected.loc['Sync Warning',column_name] ='{0} ({1}%)'.format(dfs_corrected.loc['Sync Warning',column_name], int(np.round((dfs_corrected.loc['Sync Warning',column_name]/n_events)*100))  )\n",
    "    m += dfs_corrected.loc['Late Warning',column_name]\n",
    "    dfs_corrected.loc['Late Warning',column_name] ='{0} ({1}%)'.format(dfs_corrected.loc['Late Warning',column_name], int(np.round((dfs_corrected.loc['Late Warning',column_name]/n_events)*100))  )\n",
    "    m += dfs_corrected.loc['Event and EWS close to threshold',column_name]\n",
    "    dfs_corrected.loc['Event and EWS close to threshold',column_name] ='{0} ({1}%)'.format(dfs_corrected.loc['Event and EWS close to threshold',column_name], int(np.round((dfs_corrected.loc['Event and EWS close to threshold',column_name]/n_events)*100))  )\n",
    "    dfs_corrected.loc['Missed Outbreaks',column_name] = '{0} ({1}%)'.format( n_events  - m, int(np.round(( (n_events - m)/n_events)*100))  ) \n",
    "    \n",
    "    return dfs_corrected\n",
    "\n",
    "\n",
    "\n",
    "def format_column_int(dfs_corrected, column_name, n_events):\n",
    "\n",
    "    dfs_corrected.loc['Total Events',column_name] =n_events\n",
    "    m = dfs_corrected.loc['Early Warning',column_name]\n",
    "    dfs_corrected.loc['Early Warning',column_name] = dfs_corrected.loc['Early Warning',column_name]\n",
    "    m += dfs_corrected.loc['Sync Warning',column_name]\n",
    "    dfs_corrected.loc['Sync Warning',column_name] =dfs_corrected.loc['Sync Warning',column_name]\n",
    "    m += dfs_corrected.loc['Late Warning',column_name]\n",
    "    dfs_corrected.loc['Late Warning',column_name] =dfs_corrected.loc['Late Warning',column_name]\n",
    "    m += dfs_corrected.loc['Event and EWS close to threshold',column_name]\n",
    "    dfs_corrected.loc['Event and EWS close to threshold',column_name] =dfs_corrected.loc['Event and EWS close to threshold',column_name]\n",
    "    dfs_corrected.loc['Missed Outbreaks',column_name] = n_events  - m\n",
    "    \n",
    "    return dfs_corrected\n",
    "\n",
    "\n",
    "\n",
    "def create_experiment_folders(experiment_name):\n",
    "    '''\n",
    "        Creates the folder tree of a new experiment inside path_output_results:\n",
    "        <experiment_name>/ with the subfolders Figures, performance_per_fit and pickle.\n",
    "\n",
    "        Raises an Exception when the experiment folder already exists.\n",
    "    '''\n",
    "    experiment_root = path_output_results + '/{0}/'.format(experiment_name)\n",
    "\n",
    "    # Never write into the folder of an earlier run\n",
    "    if os.path.exists(experiment_root):\n",
    "        raise Exception('experiment exists already')\n",
    "\n",
    "    os.mkdir(experiment_root)\n",
    "    for subfolder in ('Figures', 'performance_per_fit', 'pickle'):\n",
    "        os.mkdir(path_output_results + '/{0}/{1}'.format(experiment_name, subfolder))\n",
    "        \n",
    "        \n",
    "def get_binary_timeseries(term, df_timeseries, df, verbose=False):\n",
    "    '''\n",
    "        Builds a 0/1 timeseries for one term, with 1 on every activation date found in df.\n",
    "\n",
    "        INPUT\n",
    "        _____\n",
    "\n",
    "        term : value compared against df['term']; also the name of the output column\n",
    "        df_timeseries : Pandas DataFrame with a 'JHU_cases' column; provides the index of the result\n",
    "        df : Pandas DataFrame with 'term' and 'site_name' columns; the cells of the matching\n",
    "            rows are scanned for date strings in '%d-%b-%y' format\n",
    "        verbose : bool, prints every string cell that is inspected\n",
    "\n",
    "        OUTPUT\n",
    "        ______\n",
    "\n",
    "        Single-column Pandas DataFrame named after term.\n",
    "\n",
    "        NOTE: the site is read from the module-level variable site_name, not from an argument.\n",
    "        NOTE(review): the first two flattened cells are skipped, which presumably assumes that\n",
    "        'term' and 'site_name' are the first two columns and that a single row matches -- confirm.\n",
    "    '''\n",
    "    binary_timeseries_term = df_timeseries[['JHU_cases']].copy()*0\n",
    "    binary_timeseries_term.columns = [term]\n",
    "\n",
    "    matching_rows = df[(df['term'] == term) & (df['site_name'] == site_name)]\n",
    "    candidate_cells = matching_rows.values.ravel()[2:]\n",
    "\n",
    "    for cell in candidate_cells:\n",
    "        # Only string cells can hold a date\n",
    "        if not isinstance(cell, str):\n",
    "            continue\n",
    "        if verbose: print(cell)\n",
    "\n",
    "        # Strings without a dash are not dates\n",
    "        if '-' not in cell:\n",
    "            continue\n",
    "\n",
    "        iso_date = datetime.datetime.strptime(cell, '%d-%b-%y').strftime('%Y-%m-%d')\n",
    "        binary_timeseries_term.loc[iso_date, term] = 1\n",
    "\n",
    "    return binary_timeseries_term\n",
    "\n",
    "def get_training_test_dates(n_events_training,\n",
    "                            event_dates_gold_standard,\n",
    "                            timeseries_events_binary,\n",
    "                            forward_window=3,\n",
    "                           verbose=False):\n",
    "    \n",
    "    '''\n",
    "        Computes start date and end date for both training and test dataset\n",
    "\n",
    "        INPUT\n",
    "        _____\n",
    "\n",
    "        n_events_training : int\n",
    "            Number of leading events used for training. The event at this position is the test event.\n",
    "\n",
    "        event_dates_gold_standard : Pandas DataFrame\n",
    "            One row per event with 'date_start' and 'date_end' columns; both must be labels\n",
    "            of the index of timeseries_events_binary.\n",
    "\n",
    "        timeseries_events_binary : Pandas DataFrame\n",
    "            Timeseries whose index provides the available dates.\n",
    "\n",
    "        forward_window : int\n",
    "            Number of time points after the test event that are kept in the test dataset.\n",
    "\n",
    "        verbose : bool\n",
    "\n",
    "        The module-level lookback_window_ews and site_name are read, and (site_name, n_events_training)\n",
    "        is appended to the module-level close_events when the test event starts fewer than\n",
    "        lookback_window_ews points after the end of training.\n",
    "\n",
    "        OUTPUT\n",
    "        ______\n",
    "\n",
    "        date_start_training_dataset, date_end_training_dataset, date_start_test_dataset, date_end_test_dataset, date_event\n",
    "    '''\n",
    "    \n",
    "    if verbose:\n",
    "        print('Finding start date and end date for training and test datasets')\n",
    "\n",
    "    index_dates = timeseries_events_binary.index.values.ravel()\n",
    "\n",
    "    training_events = event_dates_gold_standard.iloc[:n_events_training]\n",
    "    test_events = event_dates_gold_standard.iloc[n_events_training]\n",
    "    date_end_training_dataset = training_events.iloc[-1]['date_end']\n",
    "    date_event = test_events['date_start']\n",
    "\n",
    "    ind_event = timeseries_events_binary.index.get_loc(date_event)\n",
    "\n",
    "    # Test dataset ends forward_window points after the event, or at the last date available\n",
    "    ind_end_test = ind_event + forward_window\n",
    "    if ind_end_test < timeseries_events_binary.shape[0]:\n",
    "        date_end_test_dataset = index_dates[ind_end_test]\n",
    "    else:\n",
    "        print('Event {0} close to boundary'.format(date_event))\n",
    "        date_end_test_dataset = index_dates[-1]\n",
    "\n",
    "    date_distance = ind_event - timeseries_events_binary.index.get_loc(date_end_training_dataset)\n",
    "    if verbose:\n",
    "        print('date distance =>', date_distance)\n",
    "        print('date_event =>', date_event)\n",
    "\n",
    "    if date_distance < lookback_window_ews:\n",
    "        close_events.append((site_name, n_events_training))\n",
    "        # Clamp at 0: a negative position would silently wrap around to the end of the series\n",
    "        date_start_test_dataset = index_dates[max(ind_event - lookback_window_ews, 0)]\n",
    "    else:\n",
    "        date_start_test_dataset = date_end_training_dataset\n",
    "\n",
    "    date_start_training_dataset = index_dates[0]\n",
    "    \n",
    "    return date_start_training_dataset, date_end_training_dataset, date_start_test_dataset, date_end_test_dataset, date_event\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def label_events(df_to_benchmark : pd.DataFrame,\n",
    "                   lookback_window=6,\n",
    "                 verbose=False,\n",
    "                 tp_late_tolerance = 0\n",
    "                   ):\n",
    "    \n",
    "    '''\n",
    "        Maps events to classification labels:\n",
    "        \n",
    "        True Positive => Predictor event precedes Target Event within lookback_window\n",
    "        False Positive => Predictor event activates but no target event\n",
    "        True Negative => --- \n",
    "        False Negative => Target Event activates but no predictor event precedes it\n",
    "        \n",
    "        \n",
    "        INPUT\n",
    "        _____\n",
    "        \n",
    "        df_to_benchmark : Pandas DataFrame\n",
    "            Dataframe with two columns 'x' and 'y' for predictor and target data correspondingly.\n",
    "            index is 'YYYY-mm-dd' date string.\n",
    "        \n",
    "        lookback_window : int\n",
    "            Number of time points to look for an event\n",
    "            \n",
    "        OUTPUT\n",
    "        ______\n",
    "        \n",
    "        events_labeled : Dict\n",
    "            Dictionary with the following structure:\n",
    "            \n",
    "            {\n",
    "                'tp': [] # List of tuples. each tuple contains (date_target_event, date_predictor_event),\n",
    "                'tn': [] # Empty list.\n",
    "                'fp': [] # List of dates when predictor event activated\n",
    "                'fn': [] # List of dates when target event activated\n",
    "            }\n",
    "            \n",
    "    '''\n",
    "    \n",
    "    events_labeled = {\n",
    "        'tp':[],\n",
    "        'tn':[],\n",
    "        'fp':[],\n",
    "        'fn':[]\n",
    "    }\n",
    "\n",
    "\n",
    "    x = df_to_benchmark['x'].ravel()*1\n",
    "    y = df_to_benchmark['y'].ravel()*1\n",
    "    \n",
    "    dates = list(df_to_benchmark.index.values)\n",
    "\n",
    "    # Filtering event endings\n",
    "    y[y==-1] = 0\n",
    "    x[x==-1] = 0\n",
    "\n",
    "    index_target_events = [i for i,v in enumerate(y) if v == 1]\n",
    "    index_predictor_events = [i for i,v in enumerate(x) if v == 1]\n",
    "    n_predictor_events = len(index_predictor_events)\n",
    "\n",
    "    if verbose: print('y events =>',index_target_events)\n",
    "    if verbose: print('y =>',y)\n",
    "    if verbose: print('x events =>',index_predictor_events)\n",
    "    if verbose: print('x =>',x)\n",
    "    \n",
    "    if verbose: print('Looping through target events =>')\n",
    "    for ind in index_target_events:\n",
    "        if verbose: print('Target event at => {0}'.format(ind))\n",
    "        if ind >= lookback_window: # Don't count events right next at start of boundary\n",
    "            \n",
    "            lower_ind = lower_boundary_correction(ind-lookback_window)\n",
    "            # If predictor event preceded or happened at the same time than target then its tp\n",
    "            if 1 in x[lower_ind:ind+tp_late_tolerance+1]:\n",
    "                predictor_ind = ind - lookback_window + list(x[lower_ind:ind+tp_late_tolerance+1]).index(1) # Location of first 1 within window\n",
    "                events_labeled['tp'].append( (dates[ind], dates[predictor_ind], ind-predictor_ind))\n",
    "                x[lower_ind:ind+1] = 0 # Erase from x to not count it any more\n",
    "            else:\n",
    "                events_labeled['fn'].append(dates[ind])\n",
    "        else:\n",
    "            if verbose: print('Event is right next to start of boundary')\n",
    "        \n",
    "            \n",
    "    index_events_predictor_left = [i for i,v in enumerate(x) if v == 1]\n",
    "    if verbose: print('index_events_predictor_left => ', index_events_predictor_left)\n",
    "    for ind in index_events_predictor_left:\n",
    "        events_labeled['fp'].append(dates[ind])\n",
    "        \n",
    "        \n",
    "    if verbose: print('events_labeled =>',events_labeled)\n",
    "    esum = len(index_target_events) + n_predictor_events - len(events_labeled['tp'])*2 -len(events_labeled['fp']) - len(events_labeled['tn']) - len(events_labeled['fn'])\n",
    "    \n",
    "    if verbose: print('event sum => {0}'.format(esum))\n",
    "    if verbose: print('n target events =>', len(index_target_events))\n",
    "    if verbose: print('n predictor events =>', n_predictor_events)\n",
    "    if verbose: print('tp =>',len(events_labeled['tp']))\n",
    "    if verbose: print('fp =>',len(events_labeled['fp']))\n",
    "    if verbose: print('tn =>',len(events_labeled['tn']))\n",
    "    if verbose: print('fn =>',len(events_labeled['fn']))\n",
    "          \n",
    "    if esum != 0:\n",
    "        pass\n",
    "        if verbose: print('SUM ERROR')\n",
    "    \n",
    "    return events_labeled\n",
    "\n",
    "def lower_boundary_correction(index):\n",
    "    # If index is below 0 then map to 0\n",
    "    if index < 0:\n",
    "        return 0\n",
    "    else:\n",
    "        return index\n",
    "    \n",
    "    \n",
    "    \n",
    "def events2metrics(labeled_events):\n",
    "    '''\n",
    "    INPUT\n",
    "    _____\n",
    "    \n",
    "    labeled_events : dict\n",
    "        Python dictionary containing information about tp, tn, fp and fn\n",
    "    \n",
    "    OUTPUT\n",
    "    ______\n",
    "    \n",
    "    metrics \n",
    "    returns a list of metrics in the following order  => tp, fp, tn, fn, tpr, precision, m1\n",
    "    \n",
    "    '''\n",
    "    \n",
    "    tp = len(labeled_events['tp'])\n",
    "    \n",
    "    earliness = [tup[2] for tup in labeled_events['tp']]\n",
    "    if len(earliness) == 0:\n",
    "        earliness = float('nan')\n",
    "    elif len(earliness) == 1:\n",
    "        earliness = earliness[0]\n",
    "    \n",
    "    fp = len(labeled_events['fp'])\n",
    "    tn = len(labeled_events['tn'])\n",
    "    fn = len(labeled_events['fn'])\n",
    "    sfp = len(labeled_events['sfp'])\n",
    "    \n",
    "    if tp > 0:\n",
    "        tpr = tp/(tp+fn)\n",
    "        precision = tp/(tp+fp)\n",
    "        m1 = tp/(tp+fn+fp)\n",
    "        m2 = tp/(tp+fn+sfp)\n",
    "        m3 = (m1 + m2)/2\n",
    "    else:\n",
    "        tpr = 0\n",
    "        precision = 0\n",
    "        m1 = 0\n",
    "        m2 = 0\n",
    "        m3 = 0\n",
    "    return [tp, fp, tn, fn, sfp, tpr, precision, m1, m2, m3, earliness]\n",
    "\n",
    "def generate_ROC_dataFrame(ews_timeseries, target_timeseries_binary, n_non_zero_weight_predictors=1, ensemble_thresholds = None):\n",
    "    '''\n",
    "        Evaluates the EWS at several activation thresholds and collects, for each one,\n",
    "        'tp', 'fp', 'tn', 'fn', 'sfp', 'tpr', 'precision', 'm1', 'm2', 'm3' and 'earliness'.\n",
    "\n",
    "        When ensemble_thresholds is None, the thresholds are mapped_sigmoid(k) for\n",
    "        k = 0 .. n_non_zero_weight_predictors. The EWS is active where it is strictly\n",
    "        above the threshold, and events are labeled with a lookback window of 4.\n",
    "\n",
    "        Returns a Pandas DataFrame with one row per threshold, indexed by the threshold value.\n",
    "    '''\n",
    "    if ensemble_thresholds is None:\n",
    "        ensemble_thresholds = [mapped_sigmoid(k) for k in range(n_non_zero_weight_predictors+1)]\n",
    "\n",
    "    names_columns = ['tp', 'fp', 'tn', 'fn', 'sfp', 'tpr', 'precision', 'm1', 'm2', 'm3', 'earliness']\n",
    "\n",
    "    performance_per_threshold = []\n",
    "    for beta in ensemble_thresholds:\n",
    "        # Predictor ('x') = thresholded EWS, target ('y') = binary target events\n",
    "        activations = ews_timeseries > beta\n",
    "        df_to_benchmark = pd.concat([activations, target_timeseries_binary], axis=1)\n",
    "        df_to_benchmark.columns = ['x','y']\n",
    "\n",
    "        labeled = label_events(df_to_benchmark, lookback_window=4, verbose=False)\n",
    "        labeled = add_sfps(ews_timeseries, df_to_benchmark, labeled)\n",
    "        performance_per_threshold.append(events2metrics(labeled))\n",
    "\n",
    "    roc_data = pd.DataFrame(performance_per_threshold, columns=names_columns, index=ensemble_thresholds)\n",
    "    roc_data.index.name =r'$\\beta$'\n",
    "\n",
    "    return roc_data\n",
    "\n",
    "\n",
    "def add_sfps(ews_timeseries, df_to_benchmark, labeled_events, verbose=False):\n",
    "    '''\n",
    "        Detect chains of events within the EWS and add an 'sfp' entry to labeled_events:\n",
    "        the fp dates that remain once every chain is counted at most once.\n",
    "\n",
    "        Discrete EWS is composed of binary activations (0s and 1) of different streams of data.\n",
    "        These streams are aggregated and then merged using a moving average approach\n",
    "\n",
    "        INPUT\n",
    "        _____\n",
    "\n",
    "        ews_timeseries : Pandas object with the EWS values; find_chains is applied to\n",
    "            (values > 0) with gap_tolerance = 0\n",
    "        df_to_benchmark : Pandas DataFrame with an 'x' column; provides the date index\n",
    "        labeled_events : dict with 'tp' (tuples whose second element is the predictor date)\n",
    "            and 'fp' (dates)\n",
    "        verbose : bool\n",
    "\n",
    "        OUTPUT\n",
    "        ______\n",
    "\n",
    "        labeled_events : the same dict, with labeled_events['sfp'] added. An fp date is left out\n",
    "            when its chain contains the predictor date of a tp, or when an earlier fp date of the\n",
    "            same chain was already kept. Dates outside every chain share chain number 0 and are\n",
    "            therefore treated as one chain.\n",
    "    '''\n",
    "    \n",
    "    if verbose:\n",
    "        print('ews_timeseries =>')\n",
    "        display(ews_timeseries)\n",
    "    \n",
    "    if verbose: print('labeled events before filtering =>\\n', labeled_events)\n",
    "    \n",
    "    ews_event_chains = find_chains(ews_timeseries.values.ravel()>0, gap_tolerance = 0)\n",
    "    \n",
    "    if verbose: print('EWS event chains =>\\n', ews_event_chains)\n",
    "    \n",
    "    #---- Chain number per time point (0 = not part of any chain)\n",
    "    event_array = df_to_benchmark[['x']].copy()*0\n",
    "    event_array.columns = ['n_event']\n",
    "    # Work on a copy: the array behind .values can be read-only (pandas Copy-on-Write)\n",
    "    v = event_array.values.ravel().copy()\n",
    "    \n",
    "    for n_event, index_chain in enumerate(ews_event_chains):\n",
    "        for ind in index_chain:\n",
    "            v[ind] = n_event+1 #(+1 to avoid starting from 0)\n",
    "    event_array['n_event'] = v\n",
    "\n",
    "    #---- Mark chains that contain a TP date\n",
    "    tp_chains = []\n",
    "    for tp_date_target, tp_date_predictor, _ in labeled_events['tp']:\n",
    "        n_event = event_array.loc[tp_date_predictor].values[0]\n",
    "        if n_event != 0 and n_event not in tp_chains:\n",
    "            tp_chains.append(n_event)\n",
    "\n",
    "    #---- For each fp date, check if they fall within the TP chain or if they're repeated FPs events which are already accounted for\n",
    "    new_fps = []\n",
    "    already_fp = []\n",
    "    for fp_date in labeled_events['fp']:\n",
    "        n_event = event_array.loc[fp_date].values[0]\n",
    "        if verbose:\n",
    "            print('\\n\\n')\n",
    "            print('Checking fp_date =>', fp_date)\n",
    "            print('n_event =>', n_event)\n",
    "\n",
    "        if n_event in tp_chains:\n",
    "            if verbose: print('event in tp_chains, continuing, tp_chains =>', tp_chains)\n",
    "            continue\n",
    "\n",
    "        # Only the first fp date of a chain is kept\n",
    "        if n_event in already_fp:\n",
    "            continue\n",
    "\n",
    "        new_fps.append(fp_date)\n",
    "        already_fp.append(n_event)\n",
    "\n",
    "    #---- Add new definition\n",
    "    labeled_events['sfp'] = new_fps\n",
    "    \n",
    "    return labeled_events\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return  1/(1+np.exp(-x))\n",
    "def mapped_sigmoid(x):\n",
    "    return ((1/(1+np.exp(-x))) - .5)/.5\n",
    "\n",
    "def event_analysis(events,\n",
    "                   event_timeseries,\n",
    "                   voting_threshold,\n",
    "                   lookback_window,\n",
    "                   late_window,\n",
    "                   percent_to_threshold,\n",
    "                   verbose=False,\n",
    "                   threshold=1,\n",
    "                  ):\n",
    "    \n",
    "    \n",
    "    '''\n",
    "    As a way to quantify the effectiveness of our methodology, and similar to our alternative approach,\n",
    "    we defined the following categories:\n",
    "\n",
    "    Early Warning: An EWS activation occurs earlier in time in comparison to our target (at most 6 weeks earlier).\n",
    "    Synchronous Activation: An EWS activation occurs at the same date that a target event is identified.\n",
    "    Late Activation: An EWS activation is registered after a target event is identified (at most 2 weeks later)\n",
    "    Activity Increase but no Event: An EWS event is captured and a subsequent increase in target activity is also observed,\n",
    "    but not enough to be considered an epidemic event (same time window as Early Warning and Late Activation).\n",
    "    Missed event, EWS close to activation: A target event was detected \n",
    "    Missed event, no EWS activity: A target event was detected but no EWS activation was observed\n",
    "    False Alarm: An EWS activation is observed, but no event or activity increase is detected within the target.\n",
    "\n",
    "    '''\n",
    "    \n",
    "\n",
    "    if verbose:\n",
    "        print('Event =>', events)\n",
    "\n",
    "\n",
    "    # True Positive\n",
    "    n_early = 0\n",
    "    n_sync = 0\n",
    "    n_late = 0\n",
    "\n",
    "    # FP\n",
    "    n_warning_increase_but_no_outbreak = 0\n",
    "    n_false_alarm = 0\n",
    "\n",
    "    # FN\n",
    "    n_missed = 0\n",
    "    n_missed_close_threshold = 0\n",
    "\n",
    "    # For each  event in the labeled events\n",
    "    # If TP and > 0, then early_warning++\n",
    "    # If TP and == 0, then synchronous_activation++\n",
    "    # If TP and < 0, then late_activation++\n",
    "    \n",
    "    event_date_labels = []\n",
    "    \n",
    "    \n",
    "\n",
    "    if verbose:print('\\n\\n\\n Analyzing True Positives')\n",
    "        \n",
    "        \n",
    "    if len(events['tp']) + len(events['fn']) > 1:\n",
    "        print('Event with more than 1')\n",
    "        print(events)\n",
    "        \n",
    "    for date_target_activation, date_predictor_activation, n_weeks  in events['tp']:\n",
    "        \n",
    "        ind_tp = event_timeseries.index.get_loc(date_target_activation)\n",
    "        \n",
    "        if verbose:\n",
    "            print('target act => ', date_target_activation)\n",
    "            print('predictor act => ', date_predictor_activation)\n",
    "            print('n_weeks_diff =>', n_weeks)\n",
    "\n",
    "        if n_weeks > 0:\n",
    "            n_early += 1\n",
    "            if verbose: print('Adding early activation')\n",
    "            event_date_labels.append(['early',date_target_activation,date_predictor_activation])\n",
    "        elif n_weeks == 0:\n",
    "            n_sync += 1\n",
    "            if verbose: print('Adding sync activation')\n",
    "            event_date_labels.append(['sync',date_target_activation,date_predictor_activation])\n",
    "        elif n_weeks < 0:\n",
    "            n_late += 1\n",
    "            if verbose: print('Adding late activation')\n",
    "            event_date_labels.append(['late',date_target_activation,date_predictor_activation])\n",
    "\n",
    "    # if FN, then missed_event++\n",
    "    if verbose: print('\\n\\n\\n Analyzing False Negatives')\n",
    "\n",
    "    for date_target_activation  in events['fn']:\n",
    "        if verbose: print('\\n\\nDate Target Event =>', date_target_activation)\n",
    "        ind = event_timeseries.index.get_loc(date_target_activation)\n",
    "        ind_tp = event_timeseries.index.get_loc(date_target_activation)\n",
    "\n",
    "        if 1 in ((event_timeseries.iloc[ind-lookback_window:ind+late_window]['ews'] >= percent_to_threshold*voting_threshold)*1).values:\n",
    "\n",
    "            if verbose:\n",
    "                print('False Negative classified as, Missed Event, EWS close to threshold (%{0} or closer)'.format(percent_to_threshold))\n",
    "                print('EWS Timeseries =>',event_timeseries.iloc[ind-lookback_window:ind+late_window+1]['ews'])\n",
    "                print('Voting Threshold =>', voting_threshold)\n",
    "\n",
    "            n_missed_close_threshold += 1  \n",
    "            event_date_labels.append(['close',date_target_activation,''])\n",
    "\n",
    "        else:\n",
    "\n",
    "            if verbose:\n",
    "                print('False Negative classified as missed event without relevant EWS activity')\n",
    "            n_missed +=1\n",
    "            event_date_labels.append(['missed',date_target_activation,''])\n",
    "\n",
    "    # if  FP and lambda > 1 within time window, then Activity increase but no Event\n",
    "    # Add earliness\n",
    "\n",
    "    if verbose:\n",
    "        print('\\n\\n\\n')\n",
    "        print('Analyzing False Positives')\n",
    "    for date_predictor_activation in events['sfp']:\n",
    "        ind = event_timeseries.index.get_loc(date_predictor_activation)\n",
    "        \n",
    "        \n",
    "        # If false alarm occurred after event started, then ignore\n",
    "        if ind - ind_tp > 0:\n",
    "            if verbose: print('FP {0} occurred after tp, ignoring'.format(date_predictor_activation))\n",
    "            continue\n",
    "        \n",
    "        \n",
    "        if verbose:\n",
    "            print('\\n\\n')\n",
    "            print('Date =>', date_predictor_activation)\n",
    "            print('Activity =>', event_timeseries.iloc[ind-lookback_window:ind+late_window+1][['target_original', 'target_lambda','ews', 'ews_thresholded']])\n",
    "\n",
    "        if 1 in (event_timeseries.iloc[ind:ind+lookback_window+1][['target_lambda']].values >= threshold)*1:\n",
    "\n",
    "            if verbose: print('Event classified as warning associated to target increase (but no event)')\n",
    "            n_warning_increase_but_no_outbreak += 1\n",
    "            event_date_labels.append(['soft event','',date_predictor_activation])\n",
    "\n",
    "        else:\n",
    "            if verbose: print('Event classified as false Alarm')\n",
    "            n_false_alarm += 1\n",
    "            event_date_labels.append(['false alarm','',date_predictor_activation])\n",
    "\n",
    "    return [n_early, n_sync, n_late, n_warning_increase_but_no_outbreak, n_missed_close_threshold, n_false_alarm, n_missed], event_date_labels\n",
    "\n",
    "def perform_analysis(experiment_name, site_names = constants.site_names[50:], threshold=1, filter_date = None, date='2021-12-01'):\n",
    "    '''\n",
    "        Aggregates the event analysis of every pickled fit of an experiment into one summary table.\n",
    "\n",
    "        INPUT\n",
    "        _____\n",
    "\n",
    "        experiment_name : str\n",
    "            Experiment folder inside path_output_results; fits are read from <experiment_name>/pickle.\n",
    "\n",
    "        site_names : list of str\n",
    "            Sites to include. A pickle belongs to a site when its file name starts with the site name.\n",
    "\n",
    "        threshold : number\n",
    "            Passed on to event_analysis.\n",
    "\n",
    "        filter_date : None, 'after' or 'before'\n",
    "            'after' keeps the fits whose event date is >= date, 'before' keeps those whose event\n",
    "            date is < date. The event date is the target date of the first tp, or the first fn when\n",
    "            there is no tp; fits without tp and fn are always kept. Any other non-empty value\n",
    "            raises a ValueError.\n",
    "\n",
    "        date : str\n",
    "            Cut-off used by filter_date, compared as a string with the event dates.\n",
    "\n",
    "        OUTPUT\n",
    "        ______\n",
    "\n",
    "        df : Pandas DataFrame\n",
    "            One row per category plus 'Total Events'; the columns are experiment_name (counts),\n",
    "            'percent' and 'Results' (formatted text).\n",
    "\n",
    "        earliness_list : list\n",
    "            (site_name, date_target, date_predictor, earliness) for every tp that was read.\n",
    "\n",
    "        Fits that fail to load or to analyse are reported with print and left out.\n",
    "    '''\n",
    "    verbose = False\n",
    "    # Benchmark Table\n",
    "    analysis_col_names = ['Early Warning', 'Sync Warning', 'Late Warning', 'Warning Associated to Activity but no Event', 'Event and EWS close to threshold', 'False Alarm', 'Missed Outbreaks']\n",
    "\n",
    "    if filter_date and filter_date not in ('after', 'before'):\n",
    "        raise ValueError('filter_date must be None, after or before, got {0}'.format(filter_date))\n",
    "\n",
    "    site_analyses = []\n",
    "    earliness_list = []\n",
    "\n",
    "\n",
    "    #---- Function Params\n",
    "    lookback_window = 6\n",
    "    late_window = 2\n",
    "    percent_to_threshold = .5\n",
    "\n",
    "    pickle_folder = path_output_results + '/{0}/pickle'.format(experiment_name)\n",
    "    files = sorted(os.listdir(pickle_folder))\n",
    "\n",
    "    for site_name in site_names:\n",
    "\n",
    "        event_pickle_names = [f for f in files if f.startswith(site_name)]\n",
    "\n",
    "        analyses = []\n",
    "        for filename_pickle in event_pickle_names:\n",
    "\n",
    "            try:\n",
    "                #---- Loading Event Data\n",
    "                # NOTE: pickle.load can run arbitrary code, only read files written by this project\n",
    "                with open(pickle_folder + '/{0}'.format(filename_pickle), 'rb') as pickle_file:\n",
    "                    fit_data = pickle.load(pickle_file)\n",
    "                events = fit_data['performance'][late_window]\n",
    "\n",
    "                if filter_date:\n",
    "                    # The first tp decides; the first fn is only used when there is no tp\n",
    "                    if len(events['tp']) > 0:\n",
    "                        first_event = events['tp'][0]\n",
    "                        event_date = first_event[0]\n",
    "                    elif len(events['fn']) > 0:\n",
    "                        first_event = events['fn'][0]\n",
    "                        event_date = first_event\n",
    "                    else:\n",
    "                        first_event = None\n",
    "                        event_date = None\n",
    "\n",
    "                    if event_date is not None:\n",
    "                        if filter_date == 'after':\n",
    "                            skip = event_date < date\n",
    "                        else:\n",
    "                            skip = event_date >= date\n",
    "\n",
    "                        if skip:\n",
    "                            print('skipping date =>', first_event)\n",
    "                            continue\n",
    "\n",
    "                for tup in events['tp']:\n",
    "                    earliness_list.append((site_name, tup[0], tup[1], tup[2]))\n",
    "\n",
    "                event_timeseries = fit_data['timeseries']\n",
    "\n",
    "                if verbose:\n",
    "                    print('events =>', events)\n",
    "\n",
    "                if isinstance(fit_data['model'], EWS_Discrete_Ensemble):\n",
    "                    voting_threshold = fit_data['model'].voting_threshold\n",
    "                elif isinstance(fit_data['model'], EWS_Discrete):\n",
    "                    voting_threshold = fit_data['model'].beta\n",
    "                else:\n",
    "                    # Never reuse the threshold of a previous fit for an unknown model\n",
    "                    raise TypeError('unsupported model type {0}'.format(type(fit_data['model'])))\n",
    "\n",
    "                #----\n",
    "                analysis, _ = event_analysis(events,\n",
    "                                   event_timeseries,\n",
    "                                   voting_threshold,\n",
    "                                   lookback_window,\n",
    "                                   late_window,\n",
    "                                   percent_to_threshold,\n",
    "                                   verbose=verbose,\n",
    "                                   threshold=threshold\n",
    "                                  )\n",
    "                analyses.append(analysis)\n",
    "            except Exception as t:\n",
    "                print('Error with {0}'.format(site_name))\n",
    "                print(t)\n",
    "\n",
    "        site_analyses.append(pd.DataFrame(analyses,columns=analysis_col_names))\n",
    "\n",
    "    #---- Total count per category over all sites\n",
    "    df = pd.concat(site_analyses, axis=0).sum().T.to_frame()\n",
    "    df.columns = [experiment_name]\n",
    "\n",
    "    # Categories that describe a target event; their sum is the number of events\n",
    "    list_sum =['Early Warning',\n",
    "    'Sync Warning',\n",
    "    'Late Warning',\n",
    "    'Event and EWS close to threshold',\n",
    "    'Missed Outbreaks',\n",
    "    ]\n",
    "\n",
    "    # Categories that are not target events, no percentage is reported for them\n",
    "    list_not_percent = ['Warning Associated to Activity but no Event', 'False Alarm']\n",
    "    total_events = 0\n",
    "\n",
    "    for loc in list_sum:\n",
    "        total_events += df[[experiment_name]].loc[loc].values\n",
    "\n",
    "    df = df.T\n",
    "    df['Total Events'] = total_events\n",
    "    df = df.T\n",
    "    df['percent']=((df[experiment_name].values.ravel()/total_events)*100)\n",
    "    for loc in list_not_percent:\n",
    "        df.loc[loc,'percent'] = float('NaN')\n",
    "\n",
    "    #---- Text version: 'count (percent%)', or the count alone when there is no percentage\n",
    "    text_vals = []\n",
    "\n",
    "    for i, row in df.iterrows():\n",
    "        if np.isnan(np.round(row['percent'], 1)):\n",
    "            text_vals.append('{0}'.format(row[experiment_name]))\n",
    "\n",
    "        else:\n",
    "            text_vals.append('{0} ({1}%)'.format(row[experiment_name], np.round(row['percent'], 1)))\n",
    "\n",
    "    df['Results'] = text_vals\n",
    "\n",
    "    return df, earliness_list\n",
    "\n",
    "def analysis_matrix(experiment_name, site_names = constants.site_names[50:], verbose=True, threshold = 1):\n",
    "    '''Build the per-pickle outcome table of one experiment.\n",
    "\n",
    "    For every site in site_names, loads each pickle stored in\n",
    "    path_output_results/<experiment_name>/pickle whose file name starts with\n",
    "    the site name, runs event_analysis on it and keeps the result as one row.\n",
    "\n",
    "    INPUT\n",
    "    _____\n",
    "\n",
    "    experiment_name : str\n",
    "        Name of the experiment folder inside path_output_results.\n",
    "\n",
    "    site_names : list of str (default constants.site_names[50:])\n",
    "        Sites to analyse. The default is evaluated once, when the function\n",
    "        is defined.\n",
    "\n",
    "    verbose : bool (default True)\n",
    "        Prints file names, loaded events and analyses. Also forwarded to\n",
    "        event_analysis.\n",
    "\n",
    "    threshold : (default 1)\n",
    "        Forwarded unchanged to event_analysis.\n",
    "\n",
    "    OUTPUT\n",
    "    ______\n",
    "\n",
    "    df : Pandas DataFrame\n",
    "        One row per successfully analysed pickle: the outcome columns listed\n",
    "        in analysis_col_names plus 'site_name' and 'n_event' (position of the\n",
    "        pickle among the sorted files of its site). Files that raise are\n",
    "        reported with print and skipped.\n",
    "    '''\n",
    "\n",
    "    # Benchmark Table\n",
    "    analysis_col_names = ['Early Warning', 'Sync Warning', 'Late Warning', 'Warning Associated to Activity but no Event', 'Event and EWS close to threshold', 'False Alarm', 'Missed Outbreaks']\n",
    "\n",
    "    #---- Function Params\n",
    "    lookback_window = 6\n",
    "    late_window = 2\n",
    "    percent_to_threshold = .5\n",
    "\n",
    "    # The folder content does not depend on the site, so it is listed only once\n",
    "    pickle_dir = path_output_results + '/{0}/pickle'.format(experiment_name)\n",
    "    files = sorted(os.listdir(pickle_dir))\n",
    "\n",
    "    site_analyses = []\n",
    "\n",
    "    for site_name in site_names:\n",
    "\n",
    "        # NOTE(review): prefix match, a site name that is a prefix of another\n",
    "        # site name also picks up the files of the longer one\n",
    "        event_pickle_names = [f for f in files if f.startswith(site_name)]\n",
    "\n",
    "        analyses = []\n",
    "        for i, filename_pickle in enumerate(event_pickle_names):\n",
    "\n",
    "            if verbose: print(filename_pickle)\n",
    "\n",
    "            try:\n",
    "                #---- Loading Event Data\n",
    "                # NOTE: pickle.load can execute arbitrary code, only read\n",
    "                # experiment folders produced by a trusted run\n",
    "                with open(pickle_dir + '/' + filename_pickle, 'rb') as pickle_file:\n",
    "                    fit_data = pickle.load(pickle_file)\n",
    "                events = fit_data['performance'][late_window]\n",
    "\n",
    "                if verbose: print(events)\n",
    "\n",
    "                event_timeseries = fit_data['timeseries']\n",
    "\n",
    "                if verbose:\n",
    "                    print('events =>', events)\n",
    "\n",
    "                model = fit_data['model']\n",
    "                if isinstance(model, EWS_Discrete_Ensemble):\n",
    "                    voting_threshold = model.voting_threshold\n",
    "                elif isinstance(model, EWS_Discrete):\n",
    "                    voting_threshold = model.beta\n",
    "                else:\n",
    "                    # Without this branch the threshold of the previous file\n",
    "                    # would be silently reused for an unknown model type\n",
    "                    raise TypeError('Unsupported model type: {0}'.format(type(model).__name__))\n",
    "\n",
    "                #----\n",
    "                analysis, _ = event_analysis(events,\n",
    "                                   event_timeseries,\n",
    "                                   voting_threshold,\n",
    "                                   lookback_window,\n",
    "                                   late_window,\n",
    "                                   percent_to_threshold,\n",
    "                                   verbose=verbose,\n",
    "                                   threshold=threshold\n",
    "                                  )\n",
    "\n",
    "                if verbose:print('ANALYSIS =>',analysis)\n",
    "\n",
    "                analyses.append(analysis + [site_name, i])\n",
    "            except Exception as t:\n",
    "                # Best effort: a broken pickle must not stop the other files\n",
    "                print('Error with {0}'.format(site_name))\n",
    "                print(t)\n",
    "\n",
    "        analyses = pd.DataFrame(analyses,columns=analysis_col_names + ['site_name', 'n_event'])\n",
    "        site_analyses.append(analyses)\n",
    "\n",
    "    df = pd.concat(site_analyses, axis=0)\n",
    "\n",
    "    return df\n",
    "\n",
    "def gen_hit_miss_analysis(sub_df, n_max_events):\n",
    "    '''Encode the outcome of every row of sub_df as an integer code.\n",
    "\n",
    "    Codes: 0 Early Warning, 1 Sync Warning, 2 Late Warning,\n",
    "    3 Event and EWS close to threshold, 4 Missed Outbreaks.\n",
    "    The first of those columns (in that order) that equals 1 decides the code.\n",
    "    Rows where none of them equals 1 contribute nothing.\n",
    "\n",
    "    Returns a list with the codes in row order, padded at the end with NaN\n",
    "    up to n_max_events entries.\n",
    "    '''\n",
    "    # Priority-ordered (column, code) pairs\n",
    "    outcome_codes = (\n",
    "        ('Early Warning', 0),\n",
    "        ('Sync Warning', 1),\n",
    "        ('Late Warning', 2),\n",
    "        ('Event and EWS close to threshold', 3),\n",
    "        ('Missed Outbreaks', 4),\n",
    "    )\n",
    "\n",
    "    codes = []\n",
    "    for _, event_row in sub_df.iterrows():\n",
    "        for column_name, code in outcome_codes:\n",
    "            if event_row[column_name] == 1:\n",
    "                codes.append(code)\n",
    "                break\n",
    "\n",
    "    padding = [float('nan')] * (n_max_events - len(codes))\n",
    "    return codes + padding\n",
    "\n",
    "# TODO: function to add a timeseries to a dataframe (df1, df2); both have yyyy-mm-dd dates\n",
    "\n",
    "\n",
    "class EWS_Discrete:\n",
    "    '''Early warning system (EWS) that combines discrete predictor events.\n",
    "\n",
    "    fit() ranks the predictor columns against the target events, gives weight 1\n",
    "    to the best-ranked ones (0 to the rest) and selects the decision threshold\n",
    "    beta. predict() sums the weighted predictor events over a trailing window,\n",
    "    maps the sum with mapped_sigmoid and compares the result against beta.\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                max_n_predictors = 6,\n",
    "                threshold_metric = .6,\n",
    "                 lookback_window = 6,\n",
    "                metrics_to_optimize = ['tp', 'fp'],\n",
    "                metric_to_optimize_beta = 'm2',\n",
    "                 sma_moving_window = 3,\n",
    "                ):\n",
    "        '''\n",
    "            INPUT\n",
    "            _____\n",
    "\n",
    "            max_n_predictors : int or None (default 6)\n",
    "                Upper bound for the number of predictors that get a non-zero\n",
    "                weight. With None only threshold_metric limits the selection.\n",
    "\n",
    "            threshold_metric : float (default .6)\n",
    "                Minimum 'tpr' a predictor needs to be counted as a good fit.\n",
    "\n",
    "            lookback_window : int (default 6)\n",
    "                Earliness to look for events. Forwarded to label_events when\n",
    "                the individual predictors are scored.\n",
    "\n",
    "            metrics_to_optimize : list of str (default ['tp', 'fp'])\n",
    "                Metrics used, in order, to rank the predictors.\n",
    "                Available metrics => tp, fp, sfp, tn, fn, tpr, precision, m1, m2, m3\n",
    "                Formulas as documented by the original author (the values are\n",
    "                computed by events2metrics, TODO confirm against it):\n",
    "                'tpr' : tp/(tp+fp)\n",
    "                'm1' : tp/(tp+fn+fp)\n",
    "                'm2' : tp/(tp+fn+sfp)\n",
    "                'm3' : ( m1 + m2 )/2\n",
    "\n",
    "            metric_to_optimize_beta : str (default 'm2')\n",
    "                Column of the ROC table used to select beta.\n",
    "\n",
    "            sma_moving_window : int (default 3)\n",
    "                Number of trailing points that predict sums at each time t.\n",
    "        '''\n",
    "\n",
    "        self.weights = []\n",
    "        self.timeseries_selected = []\n",
    "        self.individual_proxy_scores = []\n",
    "        self.params = {\n",
    "            'max_n_predictors' : max_n_predictors,\n",
    "            'threshold_metric' : threshold_metric,\n",
    "            # Copied so that the shared default list is never aliased by an instance\n",
    "            'metrics_to_optimize' : list(metrics_to_optimize),\n",
    "            'lookback_window': lookback_window,\n",
    "            'metric_to_optimize_beta':metric_to_optimize_beta,\n",
    "        }\n",
    "        self.sma_moving_window = sma_moving_window\n",
    "        self.beta = None\n",
    "\n",
    "    def fit(self,\n",
    "                X_training,\n",
    "                Y_training,\n",
    "                verbose = False\n",
    "                 ):\n",
    "        ''' Training function V1 for EWS_discrete.\n",
    "        1.-Scores every predictor of X_training against Y_training\n",
    "           (label_events + events2metrics)\n",
    "        2.-Ranks the predictors and gives weight 1 to the top n\n",
    "        3.-Selects the decision threshold beta\n",
    "        4.-Labels the training events obtained with the selected beta\n",
    "\n",
    "        n is the number of predictors whose 'tpr' reaches threshold_metric,\n",
    "        capped by max_n_predictors. The weighted predictors are the first n of\n",
    "        the ranking, which is ordered by metrics_to_optimize.\n",
    "\n",
    "        Returns None. When no predictor reaches threshold_metric the method\n",
    "        returns early and leaves weights and beta unchanged.\n",
    "        '''\n",
    "\n",
    "        if verbose: print('fitting')\n",
    "\n",
    "        max_n_predictors = self.params['max_n_predictors']\n",
    "        threshold_metric = self.params['threshold_metric']\n",
    "        metrics_to_optimize = self.params['metrics_to_optimize']\n",
    "        lookback_window = self.params['lookback_window']\n",
    "        metric_to_optimize_beta = self.params['metric_to_optimize_beta']\n",
    "        total_n_predictors = X_training.shape[1]\n",
    "\n",
    "        # Sort direction per metric: True means that lower values rank first\n",
    "        dict_metric_sort_ascending = {\n",
    "            'tp':False,\n",
    "            'fp':True,\n",
    "            'sfp':True,\n",
    "            'tn':False,\n",
    "            'fn':True,\n",
    "            'tpr':False,\n",
    "            'precision':False,\n",
    "            'm1':False,\n",
    "            'm2':False,\n",
    "            'm3':False,\n",
    "        }\n",
    "        sort_ascending = [dict_metric_sort_ascending[metric] for metric in metrics_to_optimize]\n",
    "\n",
    "        metric_columns = ['tp', 'fp', 'sfp', 'tn', 'fn', 'tpr', 'precision', 'm1', 'm2', 'm3']\n",
    "        records = []\n",
    "        for name_predictor in list(X_training):\n",
    "\n",
    "            #---- Merge the predictor events alongside the target events\n",
    "            df_to_benchmark = pd.concat([ X_training[[name_predictor]].copy(), Y_training.copy()], axis=1)\n",
    "            df_to_benchmark.columns = ['x', 'y']\n",
    "\n",
    "            events_labeled = label_events(df_to_benchmark,\n",
    "                           lookback_window=lookback_window,\n",
    "                            )\n",
    "\n",
    "            events_labeled['sfp'] = []\n",
    "            [tp, fp, tn, fn, sfp, tpr, precision, m1, m2, m3, earliness] = events2metrics(events_labeled)\n",
    "\n",
    "            records.append({\n",
    "                'name_predictor': name_predictor,\n",
    "                'tp': tp,\n",
    "                'fp': fp,\n",
    "                'sfp': sfp,\n",
    "                'tn': tn,\n",
    "                'fn': fn,\n",
    "                'tpr': tpr,\n",
    "                'precision': precision,\n",
    "                'm1': m1,\n",
    "                'm2': m2,\n",
    "                'm3': m3,\n",
    "            })\n",
    "\n",
    "        predictor_individual_performance = pd.DataFrame(records, columns=['name_predictor'] + metric_columns)\n",
    "        predictor_individual_performance.set_index('name_predictor', inplace=True)\n",
    "        if verbose: print('fitting phase 2')\n",
    "        #---- PREDICTOR SELECTION + Decision Threshold\n",
    "\n",
    "        # VERSION 1\n",
    "        #-------- Pick the top-n based on optimizing metrics such as TPR\n",
    "        predictor_individual_performance.sort_values(metrics_to_optimize, ascending=sort_ascending, inplace=True)\n",
    "\n",
    "        n_predictors_good_fit = predictor_individual_performance[predictor_individual_performance['tpr'] >= threshold_metric].shape[0]\n",
    "\n",
    "        if n_predictors_good_fit == 0:\n",
    "            # No predictor reaches threshold_metric: the model stays unfitted\n",
    "            return\n",
    "\n",
    "        if max_n_predictors is not None:\n",
    "\n",
    "            if n_predictors_good_fit < max_n_predictors:\n",
    "                # Report the number of fit predictors (not the total number of columns)\n",
    "                warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, n_predictors_good_fit))\n",
    "                n_predictors = n_predictors_good_fit\n",
    "            else:\n",
    "                n_predictors = max_n_predictors\n",
    "\n",
    "        else:\n",
    "            n_predictors = n_predictors_good_fit\n",
    "\n",
    "        # The frame is sorted, so the first n_predictors names receive weight 1\n",
    "        self.weights = dict(zip(list(predictor_individual_performance.index.values), [1]*n_predictors + [0]*(total_n_predictors-n_predictors) ))\n",
    "        self.individual_proxy_scores = predictor_individual_performance\n",
    "\n",
    "        #-------- Select threshold based on metric_to_optimize_beta\n",
    "        if verbose: print('fitting phase 3')\n",
    "\n",
    "        n_non_zero_weight_predictors = (np.array(list(self.weights.values())) > 0).sum()\n",
    "\n",
    "        #Getting EWS timeseries using predict\n",
    "        # Must run before the single-signal branch below changes self.sma_moving_window\n",
    "        ews_predictions = self.predict(X_training, beta=0)[['ews']]\n",
    "\n",
    "        if n_non_zero_weight_predictors == 1:\n",
    "\n",
    "            if verbose: print('model for single signal')\n",
    "            # Single predictor: threshold just below mapped_sigmoid(1)\n",
    "            # (presumably the score of one active event, TODO confirm)\n",
    "            self.beta = mapped_sigmoid(1)*.99\n",
    "            self.sma_moving_window = 1\n",
    "\n",
    "        else:\n",
    "\n",
    "            roc_data = generate_ROC_dataFrame(ews_predictions, Y_training, n_non_zero_weight_predictors)\n",
    "            roc_data.sort_values(metric_to_optimize_beta, inplace=True, ascending=dict_metric_sort_ascending[metric_to_optimize_beta])\n",
    "            self.roc_data = roc_data\n",
    "\n",
    "            # Best row first: its index value is the candidate threshold\n",
    "            self.beta = roc_data.index.values[0]\n",
    "\n",
    "            if self.beta < .45:\n",
    "                if verbose: print('BETA 0')\n",
    "                # Too low a threshold is replaced by the score of half the\n",
    "                # weighted predictors\n",
    "                self.beta = mapped_sigmoid(n_non_zero_weight_predictors/2)\n",
    "\n",
    "        if verbose:\n",
    "            print('Selected beta => ', self.beta)\n",
    "\n",
    "        df_to_benchmark = pd.concat([ews_predictions > self.beta, Y_training], axis=1)\n",
    "        df_to_benchmark.columns = ['x','y']\n",
    "        # NOTE(review): lookback_window is hard-coded to 4 here instead of using\n",
    "        # self.params['lookback_window'], confirm that this is intended\n",
    "        self.events_training_current_beta = label_events(df_to_benchmark,\n",
    "                       lookback_window=4,\n",
    "                     verbose=False,\n",
    "                       )\n",
    "\n",
    "        if verbose: print('fitting phase 4')\n",
    "\n",
    "        return\n",
    "\n",
    "\n",
    "    def predict(self,X_predict,sma_moving_window=None, beta=None):\n",
    "        '''\n",
    "            Prediction function for EWS.\n",
    "            Scales the timeseries with the weights obtained through the fit function,\n",
    "            adds them up, sums the aggregated timeseries over a trailing window and\n",
    "            maps every windowed sum with mapped_sigmoid.\n",
    "\n",
    "            INPUT\n",
    "            _____\n",
    "\n",
    "            X_predict : Pandas DataFrame\n",
    "                Dataset containing the events of each timeseries. Weighted\n",
    "                predictors that cannot be read from it are logged, printed\n",
    "                and skipped.\n",
    "\n",
    "            sma_moving_window : int or None (default None)\n",
    "                Number of points to look for events retrospectively at time t.\n",
    "                None (or 0) falls back to self.sma_moving_window.\n",
    "\n",
    "            beta : float or None (default None)\n",
    "                Decision threshold. None falls back to self.beta.\n",
    "\n",
    "\n",
    "            OUTPUT\n",
    "            ______\n",
    "\n",
    "            ews_timeseries : Pandas DataFrame indexed by 'date'\n",
    "                'ews' : mapped score, NaN for the first sma_moving_window rows\n",
    "                'ews_thresholded' : True where ews >= beta\n",
    "\n",
    "        '''\n",
    "        if beta is None:\n",
    "            beta = self.beta\n",
    "\n",
    "        if not sma_moving_window:\n",
    "            sma_moving_window = self.sma_moving_window\n",
    "\n",
    "        weights = self.weights\n",
    "\n",
    "        # Sum events together\n",
    "        merged_timeseries = np.zeros(X_predict.shape[0])\n",
    "        for col_name, weight in weights.items():\n",
    "            if weight != 0:\n",
    "                try:\n",
    "                    v = X_predict[col_name].values*weight\n",
    "                    v[v<0]=0 #Removing end of events (which are marked as -1)\n",
    "                    merged_timeseries += v\n",
    "                except Exception:\n",
    "                    # Best effort: report the predictor and continue with the rest\n",
    "                    logging.exception('Missing predict : {0}'.format(col_name))\n",
    "                    print('Missing predict : {0}'.format(col_name))\n",
    "\n",
    "        # Use retrospective moving window to sum the past activities\n",
    "        ews_timeseries = np.zeros_like(merged_timeseries)\n",
    "        for i in range(sma_moving_window, len(ews_timeseries)):\n",
    "            ews_timeseries[i] = mapped_sigmoid(merged_timeseries[i-sma_moving_window+1:i+1].sum())\n",
    "        # NOTE(review): row sma_moving_window-1 already has a full window but is\n",
    "        # left as NaN, kept as in the original\n",
    "        ews_timeseries[:sma_moving_window] = float('nan')\n",
    "        ews_timeseries = pd.DataFrame({'ews':list(ews_timeseries), 'date':list(X_predict.index.values) })\n",
    "        ews_timeseries.set_index('date', inplace=True)\n",
    "        ews_timeseries['ews_thresholded']  = ews_timeseries['ews'].values >= beta\n",
    "        return ews_timeseries\n",
    "\n",
    "class EWS_Discrete_no_restriction:\n",
    "    '''Variant of the discrete EWS without the good-fit cap on the predictors.\n",
    "\n",
    "    fit() ranks the predictor columns against the target events and gives\n",
    "    weight 1 to the first max_n_predictors of the ranking (all of them when\n",
    "    the training data has fewer columns), then selects the decision threshold\n",
    "    beta. predict() sums the weighted predictor events over a trailing window,\n",
    "    maps the sum with mapped_sigmoid and compares the result against beta.\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                max_n_predictors = 6,\n",
    "                threshold_metric = .6,\n",
    "                 lookback_window = 6,\n",
    "                metrics_to_optimize = ['tp', 'fp'],\n",
    "                metric_to_optimize_beta = 'm2'):\n",
    "        '''\n",
    "            INPUT\n",
    "            _____\n",
    "\n",
    "            max_n_predictors : int or None (default 6)\n",
    "                Number of top-ranked predictors that get a non-zero weight.\n",
    "                With None, the number of predictors reaching threshold_metric\n",
    "                is used instead.\n",
    "\n",
    "            threshold_metric : float (default .6)\n",
    "                Minimum 'tpr' a predictor needs to be counted as a good fit.\n",
    "                fit stops early when no predictor reaches it.\n",
    "\n",
    "            lookback_window : int (default 6)\n",
    "                Earliness to look for events. Forwarded to label_events when\n",
    "                the individual predictors are scored.\n",
    "\n",
    "            metrics_to_optimize : list of str (default ['tp', 'fp'])\n",
    "                Metrics used, in order, to rank the predictors.\n",
    "                Available metrics => tp, fp, sfp, tn, fn, tpr, precision, m1, m2, m3\n",
    "                Formulas as documented by the original author (the values are\n",
    "                computed by events2metrics, TODO confirm against it):\n",
    "                'tpr' : tp/(tp+fp)\n",
    "                'm1' : tp/(tp+fn+fp)\n",
    "                'm2' : tp/(tp+fn+sfp)\n",
    "                'm3' : ( m1 + m2 )/2\n",
    "\n",
    "            metric_to_optimize_beta : str (default 'm2')\n",
    "                Column of the ROC table used to select beta.\n",
    "        '''\n",
    "\n",
    "        self.weights = []\n",
    "        self.timeseries_selected = []\n",
    "        self.individual_proxy_scores = []\n",
    "        self.params = {\n",
    "            'max_n_predictors' : max_n_predictors,\n",
    "            'threshold_metric' : threshold_metric,\n",
    "            # Copied so that the shared default list is never aliased by an instance\n",
    "            'metrics_to_optimize' : list(metrics_to_optimize),\n",
    "            'lookback_window': lookback_window,\n",
    "            'metric_to_optimize_beta':metric_to_optimize_beta,\n",
    "        }\n",
    "        self.beta = None\n",
    "\n",
    "    def fit(self,\n",
    "                X_training,\n",
    "                Y_training,\n",
    "                verbose = False\n",
    "                 ):\n",
    "        ''' Training function V1 for EWS_discrete.\n",
    "        1.-Scores every predictor of X_training against Y_training\n",
    "           (label_events + events2metrics)\n",
    "        2.-Ranks the predictors and gives weight 1 to the top n\n",
    "        3.-Selects the decision threshold beta\n",
    "        4.-Labels the training events obtained with the selected beta\n",
    "\n",
    "        n is max_n_predictors, or the number of columns of X_training when\n",
    "        there are fewer. With max_n_predictors None, n is the number of\n",
    "        predictors whose 'tpr' reaches threshold_metric.\n",
    "\n",
    "        Returns None. When no predictor reaches threshold_metric the method\n",
    "        returns early and leaves weights and beta unchanged.\n",
    "        '''\n",
    "\n",
    "        if verbose: print('fitting')\n",
    "\n",
    "        max_n_predictors = self.params['max_n_predictors']\n",
    "        threshold_metric = self.params['threshold_metric']\n",
    "        metrics_to_optimize = self.params['metrics_to_optimize']\n",
    "        lookback_window = self.params['lookback_window']\n",
    "        metric_to_optimize_beta = self.params['metric_to_optimize_beta']\n",
    "        total_n_predictors = X_training.shape[1]\n",
    "\n",
    "        # Sort direction per metric: True means that lower values rank first\n",
    "        dict_metric_sort_ascending = {\n",
    "            'tp':False,\n",
    "            'fp':True,\n",
    "            'sfp':True,\n",
    "            'tn':False,\n",
    "            'fn':True,\n",
    "            'tpr':False,\n",
    "            'precision':False,\n",
    "            'm1':False,\n",
    "            'm2':False,\n",
    "            'm3':False,\n",
    "        }\n",
    "        sort_ascending = [dict_metric_sort_ascending[metric] for metric in metrics_to_optimize]\n",
    "\n",
    "        metric_columns = ['tp', 'fp', 'sfp', 'tn', 'fn', 'tpr', 'precision', 'm1', 'm2', 'm3']\n",
    "        records = []\n",
    "        for name_predictor in list(X_training):\n",
    "\n",
    "            #---- Merge the predictor events alongside the target events\n",
    "            df_to_benchmark = pd.concat([ X_training[[name_predictor]].copy(), Y_training.copy()], axis=1)\n",
    "            df_to_benchmark.columns = ['x', 'y']\n",
    "\n",
    "            events_labeled = label_events(df_to_benchmark,\n",
    "                           lookback_window=lookback_window,\n",
    "                            )\n",
    "\n",
    "            events_labeled['sfp'] = []\n",
    "            [tp, fp, tn, fn, sfp, tpr, precision, m1, m2, m3, earliness] = events2metrics(events_labeled)\n",
    "\n",
    "            records.append({\n",
    "                'name_predictor': name_predictor,\n",
    "                'tp': tp,\n",
    "                'fp': fp,\n",
    "                'sfp': sfp,\n",
    "                'tn': tn,\n",
    "                'fn': fn,\n",
    "                'tpr': tpr,\n",
    "                'precision': precision,\n",
    "                'm1': m1,\n",
    "                'm2': m2,\n",
    "                'm3': m3,\n",
    "            })\n",
    "\n",
    "        predictor_individual_performance = pd.DataFrame(records, columns=['name_predictor'] + metric_columns)\n",
    "        predictor_individual_performance.set_index('name_predictor', inplace=True)\n",
    "        if verbose: print('fitting phase 2')\n",
    "        #---- PREDICTOR SELECTION + Decision Threshold\n",
    "\n",
    "        # VERSION 1\n",
    "        #-------- Pick the top-n based on optimizing metrics such as TPR\n",
    "        predictor_individual_performance.sort_values(metrics_to_optimize, ascending=sort_ascending, inplace=True)\n",
    "\n",
    "        n_predictors_good_fit = predictor_individual_performance[predictor_individual_performance['tpr'] >= threshold_metric].shape[0]\n",
    "\n",
    "        if n_predictors_good_fit == 0:\n",
    "            # No predictor reaches threshold_metric: the model stays unfitted\n",
    "            return\n",
    "\n",
    "        if max_n_predictors is not None:\n",
    "\n",
    "            # NOTE(review): the cap uses the total number of columns, so predictors\n",
    "            # below threshold_metric can still receive weight 1\n",
    "            if total_n_predictors < max_n_predictors:\n",
    "                warnings.warn('Max number of predictors is => {0} but training data only contains {1} predictors.'.format(max_n_predictors, total_n_predictors))\n",
    "                n_predictors = total_n_predictors\n",
    "            else:\n",
    "                n_predictors = max_n_predictors\n",
    "\n",
    "        else:\n",
    "            n_predictors = n_predictors_good_fit\n",
    "\n",
    "        # The frame is sorted, so the first n_predictors names receive weight 1\n",
    "        self.weights = dict(zip(list(predictor_individual_performance.index.values), [1]*n_predictors + [0]*(total_n_predictors-n_predictors) ))\n",
    "        self.individual_proxy_scores = predictor_individual_performance\n",
    "\n",
    "        #-------- Select threshold based on metric_to_optimize_beta\n",
    "        if verbose: print('fitting phase 3')\n",
    "\n",
    "        n_non_zero_weight_predictors = (np.array(list(self.weights.values())) > 0).sum()\n",
    "\n",
    "        #Getting EWS timeseries using predict\n",
    "        ews_predictions = self.predict(X_training, beta=0)[['ews']]\n",
    "\n",
    "        if n_non_zero_weight_predictors > 1:\n",
    "\n",
    "            roc_data = generate_ROC_dataFrame(ews_predictions, Y_training, n_non_zero_weight_predictors)\n",
    "            roc_data.sort_values(metric_to_optimize_beta, inplace=True, ascending=dict_metric_sort_ascending[metric_to_optimize_beta])\n",
    "            self.roc_data = roc_data\n",
    "\n",
    "            if verbose:\n",
    "                print('roc data => ')\n",
    "                print(roc_data)\n",
    "                print('beta selected =>', roc_data.index.values[0])\n",
    "\n",
    "            # Best row first: its index value is the selected threshold\n",
    "            self.beta = roc_data.index.values[0]\n",
    "\n",
    "        else:\n",
    "            # Fixed threshold when at most one predictor is weighted (original\n",
    "            # note: with one predictor the EWS timeseries can only output .45 at most)\n",
    "            self.beta = .45\n",
    "\n",
    "        df_to_benchmark = pd.concat([ews_predictions > self.beta, Y_training], axis=1)\n",
    "        df_to_benchmark.columns = ['x','y']\n",
    "        # NOTE(review): lookback_window is hard-coded to 4 here instead of using\n",
    "        # self.params['lookback_window'], confirm that this is intended\n",
    "        self.events_training_current_beta = label_events(df_to_benchmark,\n",
    "                       lookback_window=4,\n",
    "                     verbose=False,\n",
    "                       )\n",
    "\n",
    "        if verbose: print('fitting phase 4')\n",
    "\n",
    "        return\n",
    "\n",
    "\n",
    "    def predict(self,X_predict,sma_moving_window=3, beta=None):\n",
    "        '''\n",
    "            Prediction function for EWS.\n",
    "            Scales the timeseries with the weights obtained through the fit function,\n",
    "            adds them up, sums the aggregated timeseries over a trailing window and\n",
    "            maps every windowed sum with mapped_sigmoid.\n",
    "\n",
    "            INPUT\n",
    "            _____\n",
    "\n",
    "            X_predict : Pandas DataFrame\n",
    "                Dataset containing the events of each timeseries. Weighted\n",
    "                predictors that cannot be read from it are logged, printed\n",
    "                and skipped.\n",
    "\n",
    "            sma_moving_window : int (default 3)\n",
    "                Number of points to look for events retrospectively at time t.\n",
    "\n",
    "            beta : float or None (default None)\n",
    "                Decision threshold. None falls back to self.beta.\n",
    "\n",
    "\n",
    "            OUTPUT\n",
    "            ______\n",
    "\n",
    "            ews_timeseries : Pandas DataFrame indexed by 'date'\n",
    "                'ews' : mapped score, NaN for the first sma_moving_window rows\n",
    "                'ews_thresholded' : True where ews >= beta\n",
    "\n",
    "        '''\n",
    "\n",
    "        weights = self.weights\n",
    "        if beta is None:\n",
    "            beta = self.beta\n",
    "        # Sum events together\n",
    "        merged_timeseries = np.zeros(X_predict.shape[0])\n",
    "        for col_name, weight in weights.items():\n",
    "            if weight != 0:\n",
    "                try:\n",
    "                    v = X_predict[col_name].values*weight\n",
    "                    v[v<0]=0 #Removing end of events (which are marked as -1)\n",
    "                    merged_timeseries += v\n",
    "                except Exception:\n",
    "                    # Best effort: report the predictor and continue with the rest\n",
    "                    logging.exception('Missing predict : {0}'.format(col_name))\n",
    "                    print('Missing predict : {0}'.format(col_name))\n",
    "\n",
    "        # Use retrospective moving window to sum the past activities\n",
    "        ews_timeseries = np.zeros_like(merged_timeseries)\n",
    "        for i in range(sma_moving_window, len(ews_timeseries)):\n",
    "            ews_timeseries[i] = mapped_sigmoid(merged_timeseries[i-sma_moving_window+1:i+1].sum())\n",
    "        # NOTE(review): row sma_moving_window-1 already has a full window but is\n",
    "        # left as NaN, kept as in the original\n",
    "        ews_timeseries[:sma_moving_window] = float('nan')\n",
    "        ews_timeseries = pd.DataFrame({'ews':list(ews_timeseries), 'date':list(X_predict.index.values) })\n",
    "        ews_timeseries.set_index('date', inplace=True)\n",
    "        ews_timeseries['ews_thresholded']  = ews_timeseries['ews'].values >= beta\n",
    "        return ews_timeseries\n",
    "    \n",
    "class EWS_Discrete_Ensemble:\n",
    "    def __init__(self,\n",
    "                max_n_predictors = 6,\n",
    "                threshold_metric = .7,\n",
    "                 lookback_window = 6,\n",
    "                metrics_to_optimize = ['tp', 'fp'],\n",
    "                metric_to_optimize_beta = 'm2',\n",
    "                 sma_moving_window=3,\n",
    "                n_ews = None,\n",
    "                seed = 2,\n",
    "                split = False,\n",
    "                 max_n_ews = float('inf'),\n",
    "                ):\n",
    "        '''\n",
    "            INPUT\n",
    "            _____\n",
    "            \n",
    "            max_n_predictors : int (default 6)\n",
    "                Sets a threshold to the number of input variables to use for the ews.\n",
    "            \n",
    "            threshold_metric : float (default .6)\n",
    "                After sorting, filters predictors that meet the threshold criteria for\n",
    "                m1\n",
    "            \n",
    "            lookback_window : int (default 4)\n",
    "                Earliness to look for events\n",
    "            \n",
    "            metrics_to_optimize = List of str\n",
    "                Available metrics => tpr, precision, tp, fp, sfp, fn, m1\n",
    "                'tpr' : tp/(tp+fp)\n",
    "                'm1' : tp/(tp+fn+fp)\n",
    "                'm2' : tp/(tp+fn+sfp)\n",
    "                'm3' : ( m1 + m2 )/2\n",
    "                \n",
    "            sma_moving_window : int (default 3)\n",
    "                Number of points to look for events retrospectively at time t.\n",
    "                \n",
    "            split : boolean (default false)\n",
    "                If True, then splits training dataset in an in_sample dataset (n-1 events) and out_of_sample (1 event)\n",
    "                to find the threshold\n",
    "        '''\n",
    "        \n",
    "        self.weights = []\n",
    "        self.timeseries_selected = []\n",
    "        self.individual_proxy_scores = []\n",
    "        self.params = {\n",
    "            'max_n_predictors' : max_n_predictors,\n",
    "            'threshold_metric' : threshold_metric,\n",
    "            'metrics_to_optimize' :metrics_to_optimize,\n",
    "            'lookback_window': lookback_window,\n",
    "            'metric_to_optimize_beta':metric_to_optimize_beta,\n",
    "            'seed':2,\n",
    "            'split':split,\n",
    "            'max_n_ews':max_n_ews\n",
    "        }\n",
    "        self.models = None\n",
    "        self.n_ews = None\n",
    "        self.sma_moving_window=sma_moving_window\n",
    "        self.voting_threshold = None\n",
    "        #print('EWS for discrete activations created')\n",
    "    \n",
    "    def fit(self,\n",
    "                X_training,\n",
    "                Y_training,\n",
    "                 ):\n",
    "        ''' Training function V1 for EWS_discrete.\n",
    "        1.-Computes TP, FP, FN for each event\n",
    "        2.-Selects top n predictors\n",
    "        3.-Squashes them into a probability using a Sigmoid function\n",
    "\n",
    "        Metrics to optimize:\n",
    "            'tpr' : tp/(tp+fp)\n",
    "            'm1' : tp/(tp+fn+fp)\n",
    "            'm2': tp/(tp+fn+sfp) #SFP are filtered false positives for the EWS\n",
    "        '''\n",
    "        \n",
    "        max_n_predictors = self.params['max_n_predictors']\n",
    "        threshold_metric = self.params['threshold_metric']\n",
    "        metrics_to_optimize = self.params['metrics_to_optimize']\n",
    "        lookback_window = self.params['lookback_window']\n",
    "        metric_to_optimize_beta = self.params['metric_to_optimize_beta']\n",
    "        n_ews = self.n_ews\n",
    "        split = self.params['split']\n",
    "        max_n_ews = self.params['max_n_ews']\n",
    "        \n",
    "        \n",
    "        if split:    \n",
    "            X_in_sample, X_out_of_sample, Y_in_sample, Y_out_of_sample = split_training_dataset(X_training, Y_training)\n",
    "        else:\n",
    "            X_in_sample = X_training\n",
    "            X_out_of_sample = X_training\n",
    "            Y_in_sample = Y_training\n",
    "            Y_out_of_sample = Y_training\n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        weights = []\n",
    "        total_n_predictors = X_training.shape[1]\n",
    "        predictor_individual_performance = {\n",
    "            'name_predictor':[],\n",
    "            'tp':[],\n",
    "            'fp':[],\n",
    "            'sfp':[],\n",
    "            'tn':[],\n",
    "            'fn':[],\n",
    "            'tpr':[],\n",
    "            'precision':[],\n",
    "            'm1':[],\n",
    "            'm2':[],\n",
    "            'm3':[],\n",
    "        }\n",
    "\n",
    "        dict_metric_sort_ascending = {\n",
    "            'name_predictor':[],\n",
    "            'tp':False,\n",
    "            'fp':True,\n",
    "            'sfp':True,\n",
    "            'tn':False,\n",
    "            'fn':True,\n",
    "            'tpr':False,\n",
    "            'precision':False,\n",
    "            'm1':False,\n",
    "            'm2':False,\n",
    "            'm3':False,\n",
    "        }\n",
    "\n",
    "        sort_ascending = []\n",
    "        for metric in metrics_to_optimize:\n",
    "            sort_ascending.append(dict_metric_sort_ascending[metric])\n",
    "\n",
    "        for name_predictor in list(X_training):\n",
    "\n",
    "            #---- Merge the predictor events alongside the target events\n",
    "            df_to_benchmark = pd.concat([ X_in_sample[[name_predictor]].copy(), Y_in_sample.copy()], axis=1)\n",
    "            df_to_benchmark.columns = ['x', 'y']\n",
    "\n",
    "            events_labeled = label_events(df_to_benchmark,\n",
    "                           lookback_window=lookback_window,\n",
    "                            )\n",
    "            \n",
    "            events_labeled['sfp'] = []\n",
    "            [tp, fp, tn, fn, sfp, tpr, precision, m1, m2, m3, earliness] = events2metrics(events_labeled)\n",
    "            \n",
    "            predictor_individual_performance['name_predictor'].append(name_predictor)\n",
    "            predictor_individual_performance['tp'].append(tp)\n",
    "            predictor_individual_performance['fp'].append(fp)\n",
    "            predictor_individual_performance['sfp'].append(sfp)\n",
    "            predictor_individual_performance['tn'].append(tn)\n",
    "            predictor_individual_performance['fn'].append(fn)\n",
    "            predictor_individual_performance['tpr'].append(tpr)\n",
    "            predictor_individual_performance['precision'].append(precision)\n",
    "            predictor_individual_performance['m1'].append(m1)\n",
    "            predictor_individual_performance['m2'].append(m2)\n",
    "            predictor_individual_performance['m3'].append(m3)\n",
    "\n",
    "        \n",
    "        predictor_individual_performance = pd.DataFrame(predictor_individual_performance)\n",
    "        predictor_individual_performance.set_index('name_predictor', inplace=True)  \n",
    " \n",
    "        #---- PREDICTOR SELECTION + Decision Threshold \n",
    "        \n",
    "        # VERSION 1 \n",
    "        #-------- Pick the top-n based on optimizing metrics such as TPR\n",
    "        predictor_individual_performance.sort_values(metrics_to_optimize, ascending=sort_ascending, inplace=True)\n",
    "        if max_n_predictors is not None:\n",
    "            if total_n_predictors < max_n_predictors:\n",
    "                warnings.warn('Max number of predictors is => {0} but training data only contains {1} predictors.'.format(max_n_predictors, total_n_predictors))\n",
    "                n_predictors = total_n_predictors\n",
    "            else:\n",
    "                n_predictors = max_n_predictors\n",
    "        else:\n",
    "            n_predictors = predictor_individual_performance[predictor_individual_performance['tpr'] >= threshold_metric].shape[0]\n",
    "\n",
    "            if n_predictors == 0:\n",
    "                print('No predictor found to have a value above the {0} threshold for {1}'.format(threshold_metric, metrics_to_optimize))\n",
    "\n",
    "        self.individual_proxy_scores = predictor_individual_performance\n",
    "        proxies_good_score = predictor_individual_performance[predictor_individual_performance['m1'] > .6]\n",
    "        n_proxies_good_score =  proxies_good_score.shape[0]\n",
    "        #print('N proxies with {0} score above {1} = {2}'.format(metric_to_optimize_beta, threshold_metric,n_proxies_good_score))\n",
    "        \n",
    "        \n",
    "        if n_ews is None:\n",
    "            n_ews = int(n_proxies_good_score/n_predictors)\n",
    "        elif n_ews > int(n_proxies_good_score/n_predictors):\n",
    "            warnings.warn('Number of ews to fit is more than ratio between number proxies good score / number of predictors within each ews. Thus each EWS may show big similarity')\n",
    "        \n",
    "        if n_ews < 1:\n",
    "            print('Not enough predictors to generate ews (n_predictors = {0}, n_proxies with good score ={1})'.format(n_predictors, n_proxies_good_score))\n",
    "            return \n",
    "        \n",
    "        if n_ews > max_n_ews:\n",
    "            n_ews = max_n_ews\n",
    "        \n",
    "        \n",
    "        # Each model must have information about\n",
    "        # the selected proxies (names)\n",
    "        # the beta score\n",
    "        #\n",
    "        models = []\n",
    "        # Subsetting dataset at random\n",
    "        np.random.seed(self.params['seed'])\n",
    "        \n",
    "        for n_model in range(n_ews):\n",
    "            \n",
    "            \n",
    "            model_data = {}\n",
    "            selected_predictors = np.random.choice(list(range(n_proxies_good_score)), n_predictors)\n",
    "            weights = {}\n",
    "            for predictor_name in proxies_good_score.index.values[selected_predictors]:\n",
    "                weights[predictor_name] = 1\n",
    "                 \n",
    "            model_data['weights'] = weights\n",
    "            n_non_zero_weight_predictors = (np.array(list(weights.values())) > 0).sum()\n",
    "            ews_predictions = self.predict_single(X_in_sample, weights)\n",
    "            #print(ews_predictions)\n",
    "            df_to_benchmark = pd.concat([ews_predictions, Y_in_sample], axis=1)\n",
    "            df_to_benchmark.columns = ['x','y']\n",
    "\n",
    "            roc_data = generate_ROC_dataFrame(ews_predictions, Y_in_sample, n_non_zero_weight_predictors)\n",
    "            model_data['roc_data'] = roc_data\n",
    "\n",
    "            roc_data.sort_values(metric_to_optimize_beta, inplace=True, ascending=dict_metric_sort_ascending[metric_to_optimize_beta])\n",
    "            model_data['beta'] = roc_data.index.values[0]\n",
    "            \n",
    "            #print(' \\n \\n Model {0} Fit => {1}'.format(n_model, model_data))\n",
    "            #print(df_to_benchmark)\n",
    "            models.append(model_data)\n",
    "        \n",
    "        figur\n",
    "        self.models = models\n",
    "        # TO DO, if model is not good out of sample\n",
    "        \n",
    "\n",
    "        #---- Finding voting threshold\n",
    "        n_models = len(self.models)\n",
    "        ews_timeseries = self.predict(X_out_of_sample)\n",
    "        thresholds = list((range(n_models)))\n",
    "        roc_data_ensemble = generate_ROC_dataFrame(ews_timeseries[['ews']], Y_out_of_sample, ensemble_thresholds = thresholds)\n",
    "        self.roc_data = roc_data_ensemble\n",
    "        roc_data_ensemble.sort_values(metric_to_optimize_beta, inplace=True, ascending=dict_metric_sort_ascending[metric_to_optimize_beta])\n",
    "        self.voting_threshold = roc_data_ensemble.index.values[0] \n",
    "            \n",
    "            \n",
    "        \n",
    "            \n",
    "        return\n",
    "    \n",
    "    \n",
    "    def predict(self, X_predict):\n",
    "        \n",
    "        '''\n",
    "            Wrapper function. Grabs each model separately and predicts output.\n",
    "            Decision is taken as a voting ensemble\n",
    "        '''\n",
    "        \n",
    "        models = self.models\n",
    "        n_models = len(models)\n",
    "        model_individual_predictions = []\n",
    "        dates = X_predict.index.values.ravel()\n",
    "        if self.voting_threshold is None:\n",
    "            voting_threshold = n_models/2\n",
    "        else:\n",
    "            voting_threshold = self.voting_threshold\n",
    "\n",
    "        \n",
    "        # Subsetting dataset at random\n",
    "        for n_model in range(n_models):\n",
    "            model_data = models[n_model]\n",
    "            weights = model_data['weights']\n",
    "            beta = model_data['beta']\n",
    "            n_non_zero_weight_predictors = (np.array(list(weights.values())) > 0).sum()\n",
    "            \n",
    "            ews_predictions = self.predict_single(X_predict, weights)\n",
    "            ews_predictions.columns = [str(n_model)]\n",
    "            ews_predictions['{0}_thresholded'.format(n_model)] = ews_predictions.values.ravel() >= beta\n",
    "            model_individual_predictions.append(ews_predictions.copy())\n",
    "\n",
    "        model_individual_predictions = pd.concat( model_individual_predictions, axis=1)\n",
    "        model_individual_predictions['ews'] = model_individual_predictions[[str(n_model) + '_thresholded' for n_model in range(n_models)]].sum(axis=1).values\n",
    "        model_individual_predictions['ews_thresholded'] = model_individual_predictions['ews'] >= voting_threshold\n",
    "        \n",
    "        return model_individual_predictions\n",
    "    \n",
    "    def predict_single(self,X_predict, weights):\n",
    "        '''\n",
    "            Prediction function for EWS.\n",
    "            Scales the timeseries based on the weights obtained through the fit function\n",
    "            merges them (aggregate), applies a moving average over the aggregated timeseries\n",
    "            and finally maps the timeseries using a sigmoid function with a domain alteration.\n",
    "\n",
    "            INPUT\n",
    "            _____\n",
    "\n",
    "            X_predict : Pandas DataFrame\n",
    "                Dataset containing the events of each timeseries.\n",
    "\n",
    "            \n",
    "\n",
    "\n",
    "            OUTPUT\n",
    "            ______\n",
    "\n",
    "            ews_timeseries : Np.array\n",
    "                Each value represents\n",
    "\n",
    "        '''\n",
    "        \n",
    "        # Sum events together\n",
    "        sma_moving_window = self.sma_moving_window\n",
    "        merged_timeseries = np.zeros(X_predict.shape[0])\n",
    "        for col_name, weight in weights.items():\n",
    "            if weight != 0:\n",
    "                try:\n",
    "                    v = X_predict[col_name].values*weight\n",
    "                    v[v<0]=0 #Removing end of events (which are marked as -1)\n",
    "                    merged_timeseries += v\n",
    "                except Exception as t:\n",
    "                    logging.Exception('Missing predict : {0}'.format(col_name))\n",
    "                    print('Missing predict : {0}'.format(col_name))\n",
    "\n",
    "        # Use restrospective moving window to sum the past activities\n",
    "        ews_timeseries = np.zeros_like(merged_timeseries)\n",
    "        for i in range(sma_moving_window, len(ews_timeseries)):\n",
    "            ews_timeseries[i] = mapped_sigmoid(merged_timeseries[i-sma_moving_window+1:i+1].sum())\n",
    "        ews_timeseries[:sma_moving_window] = float('nan')\n",
    "        ews_timeseries = pd.DataFrame({'ews_discrete':list(ews_timeseries), 'date':list(X_predict.index.values) })\n",
    "        ews_timeseries.set_index('date', inplace=True)\n",
    "        return ews_timeseries\n",
    "\n",
    "def split_training_dataset(X_training, Y_training):\n",
    "    '''\n",
    "        Helper function that splits n events into n-1 and 1 event datasets\n",
    "    '''\n",
    "    events = Y_training[Y_training[name_target] == 1 ]\n",
    "    dates = list(Y_training.index.values)\n",
    "    if  events.shape[0] > 1:\n",
    "        index_split = dates.index(events.index.values[-2])\n",
    "        X_in_sample = X_training.iloc[:index_split+2]\n",
    "        X_out_of_sample = X_training.iloc[index_split+1:]\n",
    "        Y_in_sample = Y_training.iloc[:index_split+2]\n",
    "        Y_out_of_sample = Y_training.iloc[index_split+1:]\n",
    "    return X_in_sample, X_out_of_sample, Y_in_sample, Y_out_of_sample\n",
    "\n",
    "\n",
    "        \n",
    "# Populate events\n",
    "\n",
    "def add_new_events_to_df(df_events, new_events, name_timeseries):\n",
    "    '''FUNCTION add_new_events_to_df\n",
    "    '''\n",
    "    \n",
    "    try:\n",
    "        df_events.drop(df_events[df_events['name_timeseries'] == name_timeseries ].index.values, axis=0, inplace=True)\n",
    "    except Exception as t:\n",
    "        print(t)\n",
    "        pass\n",
    "\n",
    "    df_events = pd.concat([df_events, new_events], axis=0).reset_index()[['site_name', 'n_event', 'date_start', 'date_end', 'name_timeseries']]\n",
    "    \n",
    "    return df_events\n",
    "    '''ENFUNCTION'''\n",
    "\n",
    "\n",
    "def find_chains(event_list, gap_tolerance = 7, min_n_events=2):\n",
    "    '''\n",
    "        Looks for continuous true events within a list of True/False and generates a list of subsets\n",
    "        of such events\n",
    "        \n",
    "        Parameters\n",
    "        __________\n",
    "        \n",
    "        indices_list : List of Boolean\n",
    "        \n",
    "        gap_tolerance : int\n",
    "            considers events within a chain even after gap_tolerance number of samples being False.\n",
    "            \n",
    "            \n",
    "        min_n_events : int\n",
    "            \n",
    "    \n",
    "    '''\n",
    "    \n",
    "    indices_list = [i for i, val in enumerate(event_list) if val == True]\n",
    "    chains = []\n",
    "    chain = []\n",
    "    pointer = 0\n",
    "    \n",
    "    while pointer < len(indices_list)-1:\n",
    "        \n",
    "        chain = [pointer]\n",
    "        subsequent = True\n",
    "        while subsequent and pointer < len(indices_list)-1:\n",
    "            \n",
    "            #print('indices_list[pointer+1]=>', indices_list[pointer+1])\n",
    "            #print('indices_list[chain[-1]]=>', indices_list[chain[-1]])\n",
    "            #print('indices_list[pointer+1] - indices_list[chain[-1]] < gap_tolerance =>', indices_list[pointer+1] - indices_list[chain[-1]] < gap_tolerance)\n",
    "            \n",
    "            index_distance = indices_list[pointer+1] - indices_list[chain[-1]]\n",
    "            \n",
    "            if index_distance == 0:\n",
    "                return('Error: Repeated Index')\n",
    "            \n",
    "            if index_distance == 1: # Subsequent events\n",
    "                chain.append(pointer+1)\n",
    "            elif indices_list[pointer+1] - indices_list[chain[-1]] < gap_tolerance:\n",
    "                chain.append(pointer+1)\n",
    "            else:\n",
    "                subsequent = False\n",
    "                if len(chain) >= min_n_events:\n",
    "                    chains.append(chain)\n",
    "                chain = []\n",
    "            pointer +=1\n",
    "    \n",
    "    \n",
    "    if len(chain) >= min_n_events:\n",
    "        chains.append(chain)\n",
    "    \n",
    "    \n",
    "    timeseries_indices =[]\n",
    "    for chain in chains:\n",
    "        timeseries_indices.append( list( np.array(indices_list)[np.array(chain)] )  )\n",
    "    \n",
    "    return timeseries_indices\n",
    "    \n",
    "# Compute events\n",
    "def get_event_from_lambda_timeseries(df_lambda, name_signal, threshold = 1, gap_tolerance = 2, verbose=False):\n",
    "    '''FUNCTION get_event_from_lambda_timeseries\n",
    "\n",
    "        Builds an events table from a lambda timeseries.\n",
    "\n",
    "        Values of df_lambda strictly greater than threshold are flagged and\n",
    "        grouped with find_chains (using gap_tolerance); every chain becomes\n",
    "        one row whose date_start / date_end are the index values of the\n",
    "        first / last sample of the chain.\n",
    "\n",
    "        NOTE: site_name is not a parameter, it is read from the enclosing\n",
    "        (notebook) scope. verbose is accepted but not used.\n",
    "\n",
    "        Returns a DataFrame with the columns 'site_name', 'n_event',\n",
    "        'date_start', 'date_end' and 'name_timeseries' (= name_signal).\n",
    "    '''\n",
    "\n",
    "    #Getting Events\n",
    "    dates = list(df_lambda.index.values)\n",
    "    above_threshold = list(df_lambda.values.ravel() > threshold)\n",
    "    chains = find_chains(above_threshold, gap_tolerance= gap_tolerance)\n",
    "\n",
    "    date_start = [dates[chain[0]] for chain in chains]\n",
    "    date_end = [dates[chain[-1]] for chain in chains]\n",
    "    site_names = [site_name for _ in chains]\n",
    "    n_event = list(range(len(chains)))\n",
    "\n",
    "    df = pd.DataFrame({\n",
    "        'site_name' : site_names,\n",
    "        'n_event' : n_event,\n",
    "        'date_start':date_start,\n",
    "        'date_end':date_end,\n",
    "        'name_timeseries':name_signal\n",
    "    })\n",
    "    return df\n",
    "\n",
    "\n",
    "# function generate timeseries file (dates)\n",
    "def generate_timeseries_file(dates, values_timeseries, name_timeseries):\n",
    "    '''FUNCTION generate_timeseries_file'''\n",
    "    df = pd.DataFrame({'date':dates, name_timeseries: values_timeseries})\n",
    "    df.set_index('date', inplace=True)\n",
    "    '''ENDFUNCTION'''\n",
    "    return  df\n",
    "\n",
    "# function generate event file\n",
    "def generate_event_file():\n",
    "    '''FUNCTION generate_timeseries_file'''\n",
    "    '''ENDFUNCTION'''\n",
    "    return pd.DataFrame({'site_name':[], 'n_event':[],'date_start':[], 'date_end':[], 'name_timeseries':[]})\n",
    "    \n",
    "    \n",
    "def load_timeseries_file(path_file):\n",
    "    '''FUNCTION load_timeseries_file'''\n",
    "    '''ENDFUNCTION'''\n",
    "    return pd.read_csv(path_file, index_col=0)\n",
    "\n",
    "\n",
    "# load_event_file():\n",
    "def load_event_file(path_file):\n",
    "    '''FUNCTION load_timeseries_file'''\n",
    "    '''ENDFUNCTION'''\n",
    "    return pd.read_csv(path_file, index_col=0)\n",
    "\n",
    "\n",
    "\n",
    "def merge_datasets(df1, df2, rewrite_columns=False, suffix = ''):\n",
    "    '''FUNCTION merge_datasets \n",
    "        merges two dataframes (from df2 to df1), adds a suffix to columns from df2 beforehand\n",
    "        if rewrite_columns, then columns with same name are dropped from df1\n",
    "    '''\n",
    "    col_names_df1 = list(df1)\n",
    "    \n",
    "    if suffix == '':\n",
    "        col_names_df2 = df2.columns\n",
    "        \n",
    "    else:\n",
    "        col_names_df2 = [suffix + '_'+ name for name in df2.columns]\n",
    "        df2.columns = col_names_df2\n",
    "        \n",
    "    \n",
    "    if rewrite_columns:\n",
    "        for col_name in col_names_df2:\n",
    "            if col_name in col_names_df1:\n",
    "                df1.drop(col_name, axis=1, inplace=True)\n",
    "    else:\n",
    "        for col_name in col_names_df2:\n",
    "            if col_name in col_names_df1:\n",
    "                raise Exception('Columns already exist in df1')\n",
    "    \n",
    "    df1 = pd.concat([df1,df2], axis=1)\n",
    "    return df1    \n",
    "    '''ENDFUNCTION'''\n",
    "\n",
    "\n",
    "def add_timeseries_to_df(df1,df2,col_df1,col_df2, index_difference_tolerance=False, overwrite=False, verbose=False):\n",
    "    '''FUNCTION add_timeseries_to_df\n",
    "        \n",
    "        gets timeseries from df1 and adds to df2.\n",
    "        restrictions: Date format on index is YYYY-MM-DD\n",
    "        \n",
    "        \n",
    "        PARAMETERS:\n",
    "        \n",
    "        col_df1 : str\n",
    "            Column name from df1 to copy.\n",
    "        \n",
    "        col_df2 : str\n",
    "            Column name from df2 to write to\n",
    "        \n",
    "        index_difference_tolerance : Either a float (0, to 1) or False\n",
    "            If False, merges DF completely.\n",
    "            If float value, then checks that the number of different\n",
    "            indices from df1 to df2 is no different than len(df2.index.values). \n",
    "    '''\n",
    "    n_different_indices = len([ind for ind in df1.index.values if ind not in df2.index.values])\n",
    "    exists_within_df2 = col_df2 in list(df2)\n",
    "    \n",
    "    if verbose:\n",
    "        print('Different indices =>', n_different_indices)\n",
    "        print('exists within target df? =>', exists_within_df2)\n",
    "    \n",
    "    # If not rewrite and exists, then get out\n",
    "    if not overwrite and exists_within_df2:\n",
    "        raise Exception('Column name already exists in df2')\n",
    "    else:\n",
    "        \n",
    "        # Check difference in tolerance\n",
    "        if index_difference_tolerance:\n",
    "            if n_different_indices > index_difference_tolerance*df.index.shape[0]:\n",
    "                raise Exception('Too many indices are different between df1 and df2')\n",
    "        \n",
    "        if overwrite and exists_within_df2:\n",
    "            df2.drop(col_df2, axis=1, inplace=True)\n",
    "        \n",
    "        sub_df = df1[[col_df1]]\n",
    "        sub_df.columns = [col_df2]\n",
    "        df2 = pd.concat([df2, sub_df], axis=1)    \n",
    "    \n",
    "    return df2\n",
    "    \n",
    "#add_event_to_df():\n",
    "# compute lambda timeseries\n",
    "# All timeseries\n",
    "# Load timeseries lambda\n",
    "# Load timeseries binary\n",
    "# Load events\n",
    "\n",
    "def generate_binary_timeseries(df_timeseries, df_events, name_signal):\n",
    "    '''FUNCTION generate_binary_timeseries'''\n",
    "    \n",
    "    timeseries_binary = df_timeseries.copy()*0\n",
    "    \n",
    "    \n",
    "    for i, event_info in df_events.iterrows():\n",
    "        \n",
    "        timeseries_binary.loc[event_info['date_start'], name_signal] = 1\n",
    "        timeseries_binary.loc[event_info['date_end'], name_signal] = -1\n",
    "        \n",
    "    '''ENDFUNCTION'''    \n",
    "    return timeseries_binary \n",
    "\n",
    "\n",
    "  \n",
    "def event_analysis_singlesource(events,\n",
    "                   event_timeseries,\n",
    "                   voting_threshold,\n",
    "                   lookback_window,\n",
    "                   late_window,\n",
    "                   percent_to_threshold,\n",
    "                   verbose=False,\n",
    "                         threshold = 1\n",
    "                  ):\n",
    "    '''\n",
    "    As a way to quantify the effectiveness of our methodology, and similar to our alternative approach,\n",
    "    we defined the following categories:\n",
    "    Early Warning: An EWS activation occurs earlier in time in comparison to our target (at most 6 weeks earlier).\n",
    "    Synchronous Activation: An EWS activation occurs at the same date that a target event is identified.\n",
    "    Late Activation: An EWS activation is registered after a target event is identified (at most 2 weeks later)\n",
    "    Activity Increase but no Event: An EWS event is captured and a subsequent increase in target activity is also observed,\n",
    "    but not enough to be considered an epidemic event (same time window as Early Warning and Late Activation).\n",
    "    Missed event, EWS close to activation: A target event was detected \n",
    "    Missed event, no EWS activity: A target event was detected but no EWS activation was observed\n",
    "    False Alarm: An EWS activation is observed, but no event or activity increase is detected within the target.\n",
    "    '''\n",
    "    \n",
    "    if verbose:\n",
    "        print('Event =>', events)\n",
    "        print('Event_timeseries =>', event_timeseries)\n",
    "\n",
    "    # True Positive\n",
    "    n_early = 0\n",
    "    n_sync = 0\n",
    "    n_late = 0\n",
    "\n",
    "    # FP\n",
    "    n_warning_increase_but_no_outbreak = 0\n",
    "    n_false_alarm = 0\n",
    "\n",
    "    # FN\n",
    "    n_missed = 0\n",
    "    n_missed_close_threshold = 0\n",
    "\n",
    "    # For each  event in the labeled events\n",
    "    # If TP and > 0, then early_warning++\n",
    "    # If TP and == 0, then synchronous_activation++\n",
    "    # If TP and < 0, then late_activation++\n",
    "    \n",
    "    event_date_labels = []\n",
    "    \n",
    "    if verbose: print('\\n\\n\\n Analyzing True Positives')\n",
    "    for date_target_activation, date_predictor_activation, n_weeks  in events['tp']:\n",
    "        ind_tp = event_timeseries.index.get_loc(date_target_activation)\n",
    "        if verbose:\n",
    "            print('target act => ', date_target_activation)\n",
    "            print('predictor act => ', date_predictor_activation)\n",
    "            print('n_weeks_diff =>', n_weeks)\n",
    "        if n_weeks > 0:\n",
    "            n_early += 1\n",
    "            if verbose: print('Adding early activation')\n",
    "            event_date_labels.append(['early',date_target_activation,date_predictor_activation])\n",
    "        elif n_weeks == 0:\n",
    "            n_sync += 1\n",
    "            if verbose: print('Adding sync activation')\n",
    "            event_date_labels.append(['sync',date_target_activation,date_predictor_activation])\n",
    "        elif n_weeks < 0:\n",
    "            n_late += 1\n",
    "            if verbose: print('Adding late activation')\n",
    "            event_date_labels.append(['late',date_target_activation,date_predictor_activation])\n",
    "\n",
    "    # if FN, then missed_event++\n",
    "    if verbose: print('\\n\\n\\n Analyzing False Negatives')\n",
    "\n",
    "    for date_target_activation  in events['fn']:\n",
    "        ind_tp = event_timeseries.index.get_loc(date_target_activation)\n",
    "        \n",
    "        if verbose:\n",
    "                print('False Negative classified as missed event without relevant EWS activity')\n",
    "        n_missed +=1\n",
    "        event_date_labels.append(['missed',date_target_activation,''])\n",
    "        \n",
    "        \n",
    "        ''' Need lucas to give a timeseries for this\n",
    "        if verbose: print('\\n\\nDate Target Event =>', date_target_activation)\n",
    "        ind = event_timeseries.index.get_loc(date_target_activation)\n",
    "        if 1 in ((event_timeseries.iloc[ind-lookback_window:ind+late_window]['ews'] >= percent_to_threshold*voting_threshold)*1).values:\n",
    "            if verbose:\n",
    "                print('False Negative classified as, Missed Event, EWS close to threshold (%{0} or closer)'.format(percent_to_threshold))\n",
    "                print('EWS Timeseries =>',event_timeseries.iloc[ind-lookback_window:ind+late_window+1]['ews'])\n",
    "                print('Voting Threshold =>', voting_threshold)\n",
    "            n_missed_close_threshold += 1  \n",
    "            event_date_labels.append(['close',date_target_activation,''])\n",
    "\n",
    "        else:\n",
    "            if verbose:\n",
    "                print('False Negative classified as missed event without relevant EWS activity')\n",
    "            n_missed +=1\n",
    "            event_date_labels.append(['missed',date_target_activation,''])'''\n",
    "\n",
    "    # if  FP and lambda > 1 within time window, then Activity increase but no Event\n",
    "    # Add earliness\n",
    "\n",
    "    if verbose:\n",
    "        print('\\n\\n\\n')\n",
    "        print('Analyzing False Positives')\n",
    "    \n",
    "    \n",
    "    for date_predictor_activation in events['fp']:\n",
    "        ind = event_timeseries.index.get_loc(date_predictor_activation)\n",
    "        # ignore  events that occur after event happened\n",
    "        if ind - ind_tp > 0:\n",
    "            if verbose: print('ignoring fp due to it ocurring after tp/fn')\n",
    "            continue\n",
    "        if verbose:\n",
    "            print('\\n\\n')\n",
    "            print('Date =>', date_predictor_activation)\n",
    "            print('Activity =>', event_timeseries.iloc[ind-lookback_window:ind+late_window+1][['target_original', 'target_lambda','ews', 'ews_thresholded']])\n",
    "\n",
    "        if 1 in (event_timeseries.iloc[ind:ind+lookback_window+1][['target_lambda']].values >= threshold)*1:\n",
    "\n",
    "            if verbose: print('Event classified as warning associated to target increase (but no event)')\n",
    "            n_warning_increase_but_no_outbreak += 1\n",
    "            event_date_labels.append(['soft event','',date_predictor_activation])\n",
    "\n",
    "        else:\n",
    "            if verbose: print('Event classified as false Alarm')\n",
    "            n_false_alarm += 1\n",
    "            event_date_labels.append(['false alarm','',date_predictor_activation])\n",
    "\n",
    "    return [n_early, n_sync, n_late, n_warning_increase_but_no_outbreak, n_missed_close_threshold, n_false_alarm, n_missed], event_date_labels\n",
    "\n",
    "def perform_analysis_singlesource(experiment_name, verbose=False, site_names = constants.site_names[50:], threshold = 1, filter_date=None):\n",
    "    \n",
    "    # Benchmark Table\n",
    "    analysis_col_names = ['Early Warning', 'Sync Warning', 'Late Warning', 'Warning Associated to Activity but no Event', 'Event and EWS close to threshold', 'False Alarm', 'Missed Outbreaks']\n",
    "    site_name = 'FL_Palm Beach'\n",
    "\n",
    "    site_analyses = []\n",
    "    dataframe_analyses = []\n",
    "    earliness_list = []\n",
    "\n",
    "\n",
    "    #---- Function Params\n",
    "    lookback_window = 6\n",
    "    late_window = 2\n",
    "    percent_to_threshold = .5\n",
    "\n",
    "\n",
    "    for site_name in site_names:\n",
    "\n",
    "        files = sorted(os.listdir(path_output_results + '/{0}/pickle'.format(experiment_name)))\n",
    "        event_pickle_names = [f for f in files if '{0}'.format(site_name) == f[:len(site_name)]]\n",
    "        event_pickle_names\n",
    "\n",
    "        analyses = []\n",
    "        for filename_pickle in event_pickle_names:\n",
    "\n",
    "            try:\n",
    "                #---- Loading Event Data\n",
    "                fit_data = pickle.load(open(path_output_results + '/{0}/pickle/{1}'.format(experiment_name, filename_pickle),'rb'))\n",
    "                events = fit_data['performance'][late_window]\n",
    "                \n",
    "                \n",
    "                if len(events['fn']) > 0:\n",
    "                    print(filename_pickle, 'fn here')\n",
    "                \n",
    "                \n",
    "                if len(events['tp']) + len(events['fn']) > 1:\n",
    "                    print('More than 1 event in {0}'.format(filename_pickle))\n",
    "                \n",
    "                \n",
    "                \n",
    "                if filter_date:\n",
    "                    \n",
    "                    if len(events['tp']) > 0:\n",
    "                        \n",
    "                        if filter_date == 'after':\n",
    "                            condition = events['tp'][0][0] < '2021-12-01' #filter_date\n",
    "                        elif filter_date == 'before':\n",
    "                            condition = events['tp'][0][0] >= '2021-12-01'  \n",
    "                            \n",
    "                        \n",
    "                        if condition:\n",
    "                            print('skipping date =>', events['tp'][0])\n",
    "                            continue\n",
    "                    \n",
    "                    elif len(events['fn']) > 0:\n",
    "                        \n",
    "                        if filter_date == 'after':\n",
    "                            condition = events['fn'][0] < '2021-12-01' #filter_date\n",
    "                        elif filter_date == 'before':\n",
    "                            condition = events['fn'][0] >= '2021-12-01' \n",
    "                            \n",
    "                        if  condition: \n",
    "                            print('skipping date =>', events['fn'][0])\n",
    "                            continue\n",
    "                    \n",
    "                            \n",
    "                for tup in events['tp']:\n",
    "                    earliness_list.append((site_name, tup[0], tup[1], tup[2]))\n",
    "                event_timeseries = fit_data['timeseries']\n",
    "                \n",
    "                if verbose:\n",
    "                    print('events =>', events)\n",
    "\n",
    "                if isinstance(fit_data['model'], EWS_Discrete_Ensemble):\n",
    "                    voting_threshold = fit_data['model'].voting_threshold\n",
    "                elif isinstance(fit_data['model'], EWS_Discrete):\n",
    "                    voting_threshold = fit_data['model'].beta\n",
    "\n",
    "\n",
    "                #----\n",
    "                analysis, labs = event_analysis_singlesource(events,\n",
    "                                   event_timeseries,\n",
    "                                   voting_threshold,\n",
    "                                   lookback_window,\n",
    "                                   late_window,\n",
    "                                   percent_to_threshold,\n",
    "                                   verbose=verbose,\n",
    "                                    threshold = threshold\n",
    "                                  )\n",
    "                analyses.append(analysis)\n",
    "                labs = pd.DataFrame(labs, columns = ['Event Type', 'Target Date', 'Predictor Date'])\n",
    "                labs['site_name'] = site_name\n",
    "                labs['n_events_training'] = filename_pickle[-5] \n",
    "                dataframe_analyses.append(labs.copy())\n",
    "            except Exception as t:\n",
    "                print('Error with {0}'.format(site_name))\n",
    "                print(t)\n",
    "\n",
    "        analyses = pd.DataFrame(analyses,columns=analysis_col_names)\n",
    "        #analyses.index.name = site_name\n",
    "        site_analyses.append(analyses)\n",
    "\n",
    "    dataframe_analyses = pd.concat(dataframe_analyses, axis=0)\n",
    "    df = pd.concat(site_analyses, axis=0).sum().T.to_frame()\n",
    "    df.columns = [experiment_name]\n",
    "    \n",
    "    \n",
    "    list_sum =['Early Warning',\n",
    "    'Sync Warning',\n",
    "    'Late Warning',\n",
    "    'Event and EWS close to threshold',\n",
    "    'Missed Outbreaks',\n",
    "    ]\n",
    "\n",
    "    list_not_percent = ['Warning Associated to Activity but no Event', 'False Alarm']\n",
    "    total_events = 0\n",
    "\n",
    "    for loc in list_sum:\n",
    "        total_events += df[[experiment_name]].loc[loc].values\n",
    "\n",
    "    df = df.T\n",
    "    df['Total Events'] = total_events\n",
    "    df = df.T\n",
    "    df['percent']=((df[experiment_name].values.ravel()/total_events)*100)\n",
    "    for loc in list_not_percent:\n",
    "        df.loc[loc,'percent'] = float('NaN')\n",
    "\n",
    "    text_vals = []\n",
    "\n",
    "    for i, row in df.iterrows():\n",
    "        if np.isnan(np.round(row['percent'], 1)):\n",
    "            text_vals.append('{0}'.format(row[experiment_name]))\n",
    "        else:\n",
    "            text_vals.append('{0} ({1}%)'.format(row[experiment_name], np.round(row['percent'], 1)))\n",
    "    \n",
    "    df['Results'] = text_vals\n",
    "    return df, earliness_list\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('UpToDate',)\n",
      "['up2date']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Event 2022-01-01 close to boundary\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n",
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/4j/bxl_8jd95j13bvx64wvp8kc00000gn/T/ipykernel_9284/968188993.py:910: UserWarning: Max number of predictors is => 6 but training data only contains 1 fit predictors.\n",
      "  warnings.warn('Max number of predictors is => {0} but training data only contains {1} fit predictors.'.format(max_n_predictors, total_n_predictors))\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [25]\u001b[0m, in \u001b[0;36m<cell line: 115>\u001b[0;34m()\u001b[0m\n\u001b[1;32m    247\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m model_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mensemble\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m    248\u001b[0m     ews \u001b[38;5;241m=\u001b[39m EWS_Discrete_Ensemble(metric_to_optimize_beta\u001b[38;5;241m=\u001b[39mmetric_to_optimize_beta, \u001b[38;5;66;03m# EWS Ensemble\u001b[39;00m\n\u001b[1;32m    249\u001b[0m                                 seed\u001b[38;5;241m=\u001b[39mseed\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m    250\u001b[0m                                 max_n_ews\u001b[38;5;241m=\u001b[39m max_n_voters,\n\u001b[1;32m    251\u001b[0m                                 split\u001b[38;5;241m=\u001b[39msplit) \n\u001b[0;32m--> 253\u001b[0m \u001b[43mews\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_training\u001b[49m\u001b[43m,\u001b[49m\u001b[43mY_training\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;66;03m# Predict ------\u001b[39;00m\n\u001b[1;32m    256\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m: \n\u001b[1;32m    257\u001b[0m     \u001b[38;5;66;03m# TODO Move half voting threshold to model \u001b[39;00m\n",
      "Input \u001b[0;32mIn [19]\u001b[0m, in \u001b[0;36mEWS_Discrete.fit\u001b[0;34m(self, X_training, Y_training, verbose)\u001b[0m\n\u001b[1;32m    889\u001b[0m     predictor_individual_performance[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mm2\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(m2)\n\u001b[1;32m    890\u001b[0m     predictor_individual_performance[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mm3\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mappend(m3)\n\u001b[0;32m--> 892\u001b[0m predictor_individual_performance \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDataFrame\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredictor_individual_performance\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    893\u001b[0m predictor_individual_performance\u001b[38;5;241m.\u001b[39mset_index(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mname_predictor\u001b[39m\u001b[38;5;124m'\u001b[39m, inplace\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)  \n\u001b[1;32m    894\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verbose: \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfitting phase 2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[0;32m~/Desktop/SANTILLANA/JJ-Dengue-Forecasting/env/lib/python3.9/site-packages/pandas/core/frame.py:636\u001b[0m, in \u001b[0;36mDataFrame.__init__\u001b[0;34m(self, data, index, columns, dtype, copy)\u001b[0m\n\u001b[1;32m    630\u001b[0m     mgr \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_mgr(\n\u001b[1;32m    631\u001b[0m         data, axes\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindex\u001b[39m\u001b[38;5;124m\"\u001b[39m: index, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: columns}, dtype\u001b[38;5;241m=\u001b[39mdtype, copy\u001b[38;5;241m=\u001b[39mcopy\n\u001b[1;32m    632\u001b[0m     )\n\u001b[1;32m    634\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m    635\u001b[0m     \u001b[38;5;66;03m# GH#38939 de facto copy defaults to False only in non-dict cases\u001b[39;00m\n\u001b[0;32m--> 636\u001b[0m     mgr \u001b[38;5;241m=\u001b[39m \u001b[43mdict_to_mgr\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolumns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtyp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmanager\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    637\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, ma\u001b[38;5;241m.\u001b[39mMaskedArray):\n\u001b[1;32m    638\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m 
\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mma\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmrecords\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mmrecords\u001b[39;00m\n",
      "File \u001b[0;32m~/Desktop/SANTILLANA/JJ-Dengue-Forecasting/env/lib/python3.9/site-packages/pandas/core/internals/construction.py:502\u001b[0m, in \u001b[0;36mdict_to_mgr\u001b[0;34m(data, index, columns, dtype, typ, copy)\u001b[0m\n\u001b[1;32m    494\u001b[0m     arrays \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m    495\u001b[0m         x\n\u001b[1;32m    496\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(x, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(x\u001b[38;5;241m.\u001b[39mdtype, ExtensionDtype)\n\u001b[1;32m    497\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m x\u001b[38;5;241m.\u001b[39mcopy()\n\u001b[1;32m    498\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m arrays\n\u001b[1;32m    499\u001b[0m     ]\n\u001b[1;32m    500\u001b[0m     \u001b[38;5;66;03m# TODO: can we get rid of the dt64tz special case above?\u001b[39;00m\n\u001b[0;32m--> 502\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43marrays_to_mgr\u001b[49m\u001b[43m(\u001b[49m\u001b[43marrays\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolumns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtyp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtyp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconsolidate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/SANTILLANA/JJ-Dengue-Forecasting/env/lib/python3.9/site-packages/pandas/core/internals/construction.py:156\u001b[0m, in \u001b[0;36marrays_to_mgr\u001b[0;34m(arrays, columns, index, dtype, verify_integrity, typ, consolidate)\u001b[0m\n\u001b[1;32m    153\u001b[0m axes \u001b[38;5;241m=\u001b[39m [columns, index]\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m typ \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblock\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 156\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcreate_block_manager_from_column_arrays\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    157\u001b[0m \u001b[43m        \u001b[49m\u001b[43marrays\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconsolidate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconsolidate\u001b[49m\n\u001b[1;32m    158\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    159\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m typ \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marray\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    160\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ArrayManager(arrays, [index, columns])\n",
      "File \u001b[0;32m~/Desktop/SANTILLANA/JJ-Dengue-Forecasting/env/lib/python3.9/site-packages/pandas/core/internals/managers.py:1954\u001b[0m, in \u001b[0;36mcreate_block_manager_from_column_arrays\u001b[0;34m(arrays, axes, consolidate)\u001b[0m\n\u001b[1;32m   1937\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_block_manager_from_column_arrays\u001b[39m(\n\u001b[1;32m   1938\u001b[0m     arrays: \u001b[38;5;28mlist\u001b[39m[ArrayLike],\n\u001b[1;32m   1939\u001b[0m     axes: \u001b[38;5;28mlist\u001b[39m[Index],\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1950\u001b[0m     \u001b[38;5;66;03m# These last three are sufficient to allow us to safely pass\u001b[39;00m\n\u001b[1;32m   1951\u001b[0m     \u001b[38;5;66;03m#  verify_integrity=False below.\u001b[39;00m\n\u001b[1;32m   1953\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1954\u001b[0m         blocks \u001b[38;5;241m=\u001b[39m \u001b[43m_form_blocks\u001b[49m\u001b[43m(\u001b[49m\u001b[43marrays\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconsolidate\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1955\u001b[0m         mgr \u001b[38;5;241m=\u001b[39m BlockManager(blocks, axes, verify_integrity\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m   1956\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m~/Desktop/SANTILLANA/JJ-Dengue-Forecasting/env/lib/python3.9/site-packages/pandas/core/internals/managers.py:2028\u001b[0m, in \u001b[0;36m_form_blocks\u001b[0;34m(arrays, consolidate)\u001b[0m\n\u001b[1;32m   2025\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28missubclass\u001b[39m(dtype\u001b[38;5;241m.\u001b[39mtype, (\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbytes\u001b[39m)):\n\u001b[1;32m   2026\u001b[0m     dtype \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mdtype(\u001b[38;5;28mobject\u001b[39m)\n\u001b[0;32m-> 2028\u001b[0m values, placement \u001b[38;5;241m=\u001b[39m _stack_arrays(\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtup_block\u001b[49m\u001b[43m)\u001b[49m, dtype)\n\u001b[1;32m   2029\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_dtlike:\n\u001b[1;32m   2030\u001b[0m     values \u001b[38;5;241m=\u001b[39m ensure_wrapped_if_datetimelike(values)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Model Fitting Routine\n",
    "def filter_dataset(dataset_combination, name_predictors_for_fitting):\n",
    "    '''Return the predictor names that belong to the requested data sources.\n",
    "\n",
    "    Every name in ``name_predictors_for_fitting`` is assigned to exactly one\n",
    "    source. The checks run in a fixed order: exact Apple Mobility column\n",
    "    names first, then the substrings 'twitter', 'gt2_' / 'gt_', 'neighbor_'\n",
    "    and 'up2date'. A name matching none of them is treated as 'local_epi'.\n",
    "\n",
    "    The result concatenates the groups in the order the sources appear in\n",
    "    ``dataset_combination``; source labels that are not recognised\n",
    "    contribute nothing.\n",
    "    '''\n",
    "    apple_mobility_columns = (\n",
    "        'apple_mobility_by_driving',\n",
    "        'apple_mobility_by_walking',\n",
    "        'apple_mobility_by_transit',\n",
    "        'apple_mobility_by_driving_state',\n",
    "        'apple_mobility_by_walking_state',\n",
    "        'apple_mobility_by_transit_state',\n",
    "    )\n",
    "    # Substring rules, tried in this order once the exact Apple match failed.\n",
    "    substring_rules = (\n",
    "        ('Twitter', ('twitter',)),\n",
    "        ('Google Searches', ('gt2_', 'gt_')),\n",
    "        ('Neighbor Data', ('neighbor_',)),\n",
    "        ('UpToDate', ('up2date',)),\n",
    "    )\n",
    "    predictors_by_source = {\n",
    "        'local_epi': [],\n",
    "        'UpToDate': [],\n",
    "        'Apple Mobility': [],\n",
    "        'Twitter': [],\n",
    "        'Neighbor Data': [],\n",
    "        'Google Searches': [],\n",
    "    }\n",
    "\n",
    "    def source_of(predictor):\n",
    "        # First matching rule wins; everything left over is local epi data.\n",
    "        if predictor in apple_mobility_columns:\n",
    "            return 'Apple Mobility'\n",
    "        for source_label, markers in substring_rules:\n",
    "            if any(marker in predictor for marker in markers):\n",
    "                return source_label\n",
    "        return 'local_epi'\n",
    "\n",
    "    for predictor in name_predictors_for_fitting:\n",
    "        predictors_by_source[source_of(predictor)].append(predictor)\n",
    "\n",
    "    selected = []\n",
    "    for source_label in dataset_combination:\n",
    "        selected.extend(predictors_by_source.get(source_label, []))\n",
    "    return selected\n",
    "\n",
    "\n",
    "# Data sources, keyed by their position in the 0/1 inclusion masks below.\n",
    "sources = {\n",
    "    0: 'local_epi',\n",
    "    1: 'Google Searches',\n",
    "    2: 'Neighbor Data',\n",
    "    3: 'Twitter',\n",
    "    4: 'Apple Mobility',\n",
    "    5: 'UpToDate',\n",
    "}\n",
    "\n",
    "# Every inclusion mask over the sources, in itertools.product order.\n",
    "combinations = list(product([0, 1], repeat=len(sources)))\n",
    "\n",
    "# Turn each mask into a tuple of source names; the first mask (all zeros,\n",
    "# i.e. no source at all) is skipped.\n",
    "datasets = []\n",
    "for combination in combinations[1:]:\n",
    "    dataset = [sources[position] for position, flag in enumerate(combination) if flag == 1]\n",
    "    datasets.append(tuple(dataset))\n",
    "\n",
    "dataset_combinations = datasets\n",
    "\n",
    "\n",
    "# ARGO LIKE TEST V2\n",
    "#---- EWS Params\n",
    "\n",
    "\n",
    "\n",
    "# Predictors with a 'gt_' / 'gt2_' prefix (the 'Google Searches' source) are\n",
    "# only kept for fitting when they are listed here.\n",
    "# 'gt2_Dysgeusia' and 'gt2_Ear pain' used to be fused into the single string\n",
    "# 'gt2_Dysgeusiagt2_Ear pain' (missing separator), so neither could match.\n",
    "whitelisted_terms = [\n",
    "    'gt_after covid vaccine', 'gt_side effects of vaccine', 'gt_effects of covid vaccine',\n",
    "    'gt_covid', 'gt_how long does covid last', 'gt_anosmia', 'gt_loss smell', 'gt_covid-19',\n",
    "    'gt_loss taste', 'gt_loss of smell', 'gt_chest pain', 'gt_covid symptoms', 'gt_sars-cov 2',\n",
    "    'gt_chest tightness', 'gt_covid nhs', 'gt_quarantine', 'gt_covid-19 who', 'gt_sars-cov-2 ',\n",
    "    'gt_feeling exhausted', 'gt_nose bleed', 'gt_feeling tired', 'gt_joints aching', 'gt_fever',\n",
    "    'gt2_Abdominal pain', 'gt2_Acute bronchitis', 'gt2_Ageusia', 'gt2_Anosmia', 'gt2_Anxiety',\n",
    "    'gt2_Asphyxia', 'gt2_Asthma', 'gt2_Bronchitis', 'gt2_Burning Chest Pain', 'gt2_Cardiac arrest',\n",
    "    'gt2_Chest pain', 'gt2_Chills', 'gt2_Chronic pain', 'gt2_Cough', 'gt2_Diarrhea', 'gt2_Dizziness',\n",
    "    'gt2_Dysgeusia', 'gt2_Ear pain', 'gt2_Fatigue', 'gt2_Fever', 'gt2_Hyperventilation',\n",
    "    'gt2_Nasal congestion', 'gt2_Otitis', 'gt2_Phlegm', 'gt2_Pneumonia', 'gt2_Shortness of breath',\n",
    "    'gt2_Sore throat', 'gt2_Throat irritation', 'gt2_Vertigo', 'gt2_Viral pneumonia', 'gt2_Vomiting',\n",
    "    'gt2_Wheeze',\n",
    "]\n",
    "# whitelisted_terms = ['gt_after covid vaccine', 'gt_side effects of vaccine', 'gt_effects of covid vaccine', 'gt_covid', 'gt_how long does covid last', 'gt_anosmia', 'gt_loss smell', 'gt_covid-19', 'gt_loss taste', 'gt_loss of smell', 'gt_chest pain', 'gt_covid symptoms', 'gt_sars-cov 2', 'gt_chest tightness', 'gt_covid nhs', 'gt_quarantine', 'gt_covid-19 who', 'gt_sars-cov-2 ', 'gt_feeling exhausted', 'gt_nose bleed', 'gt_feeling tired', 'gt_joints aching', 'gt_fever', 'gt2_Abdominal pain', 'gt2_Acute bronchitis', 'gt2_Ageusia', 'gt2_Anosmia', 'gt2_Anxiety', 'gt2_Asphyxia', 'gt2_Asthma', 'gt2_Back pain', 'gt2_Breast pain', 'gt2_Bronchitis', 'gt2_Burn', 'gt2_Burning Chest Pain', 'gt2_Cardiac arrest', 'gt2_Chest pain', 'gt2_Chills', 'gt2_Chronic pain', 'gt2_Cough', 'gt2_Diarrhea', 'gt2_Dizziness', 'gt2_Ear pain', 'gt2_Fatigue', 'gt2_Fever', 'gt2_Hyperventilation', 'gt2_Major depressive disorder', 'gt2_Nasal congestion', 'gt2_Otitis', 'gt2_Phlegm', 'gt2_Pneumonia', 'gt2_Post-nasal drip', 'gt2_Sexual dysfunction', 'gt2_Shallow breathing', 'gt2_Sharp pain', 'gt2_Shivering', 'gt2_Shortness of breath', 'gt2_Sinusitis', 'gt2_Skin condition', 'gt2_Skin rash', 'gt2_Skin ulcer', 'gt2_Sleep apnea', 'gt2_Sleep deprivation', 'gt2_Sleep disorder', 'gt2_Sore throat', 'gt2_Throat irritation', 'gt2_Thrombocytopenia', 'gt2_Vasculitis', 'gt2_Ventricular fibrillation', 'gt2_Ventricular tachycardia', 'gt2_Vertigo', 'gt2_Viral pneumonia', 'gt2_Visual acuity', 'gt2_Vomiting', 'gt2_Weakness', 'gt2_Wheeze', 'gt2_Xeroderma', 'gt2_Xerostomia', 'gt2_Yawn', 'gt2_hyperhidrosis', 'gt2_pancreatitis']\n",
    "\n",
    "\n",
    "current_date = datetime.datetime.now().strftime('%Y-%m-%d')\n",
    "verbose = False\n",
    "lookback_window = 4\n",
    "lookback_window_ews = 6  # passed to label_events as lookback_window\n",
    "split_training = False\n",
    "max_n_voters = float('inf')\n",
    "# Metric handed to the EWS models as metric_to_optimize_beta:\n",
    "#   m1 = tp/(tp+fp+fn)\n",
    "#   m2 = tp/(tp+sfp+fn)\n",
    "#   m3 = (m1+m2)/2\n",
    "metric_to_optimize_beta = 'm3'\n",
    "\n",
    "#---- Fitting routine params\n",
    "overwrite_folder = True\n",
    "type_lambda = 'weekly_3d'\n",
    "late_tolerance = [-1,0,1,2]\n",
    "model_type = 'ews' #'ensemble', 'ews'\n",
    "name_signal = 'rt_parag_probability' # 'rt_cori', 'rt_parag', 'JHU_' (jhu is lambda)\n",
    "source_type = 'cases' # 'cases', 'deaths'\n",
    "close_events = []\n",
    "geo = 'state'\n",
    "\n",
    "# The first 50 entries of constants.site_names are used for 'state', the rest for 'county'.\n",
    "if geo == 'state':\n",
    "    site_names = constants.site_names[:50]\n",
    "elif geo == 'county':\n",
    "    site_names = constants.site_names[50:]\n",
    "else:\n",
    "    # Fail loudly instead of running with an undefined or stale site_names.\n",
    "    raise ValueError('geo must be state or county, got {0!r}'.format(geo))\n",
    "\n",
    "#site_names = site_names[:1]\n",
    "\n",
    "# NOTE(review): despite the _gt suffix this keeps the combinations WITHOUT\n",
    "# 'Google Searches'; the loops in this cell and the next do not use it.\n",
    "dataset_combinations_gt = [c for c in dataset_combinations if 'Google Searches' not  in c]\n",
    "# Main experiment loop. For each of the first three source combinations and\n",
    "# each site: load the site CSVs, keep the predictors of the selected sources,\n",
    "# then for n_events_training = 1 .. n_events-1 fit an EWS model on the training\n",
    "# window returned by get_training_test_dates, predict on the test window, label\n",
    "# the prediction for every late-activation tolerance, and store a pickle per\n",
    "# event plus one performance CSV per site under path_output_results.\n",
    "for dataset_combination in dataset_combinations[:3]:\n",
    "    print(dataset_combination)\n",
    "    for model_type in ['ews']: #'ews', 'ensemble'\n",
    "        for name_signal in ['rt_parag_probability']:#[ 'rt_parag_probability']:\n",
    "            for source_type in ['cases']:\n",
    "\n",
    "                name_target = '{0}_{1}'.format(name_signal, source_type) # 'JHU_{0}'.format(source_type) #_corrected\n",
    "                # Only {4} (source combination) and {3} (geo) end up in the experiment name.\n",
    "                experiment_name = 'reproducibility_{4}_{3}'.format(current_date, name_target, model_type, geo, dataset_combination)\n",
    "                try:\n",
    "                    create_experiment_folders(experiment_name)\n",
    "                except Exception:\n",
    "                    # With overwrite_folder the failure is tolerated (presumably the\n",
    "                    # folders already exist -- TODO confirm); otherwise surface the\n",
    "                    # real error instead of exiting with a bare status code.\n",
    "                    if not overwrite_folder:\n",
    "                        raise\n",
    "\n",
    "\n",
    "                # Error Events: one entry per skipped (site, event); the lists stay aligned.\n",
    "                error_tracker = {\n",
    "                    'site_name':[],\n",
    "                    'n_events_training':[],\n",
    "                    'error_type': [],\n",
    "                    'error': []\n",
    "                }\n",
    "                max_tol = np.max(late_tolerance)\n",
    "                printdataset = True\n",
    "                path_dataset = PATH_DATASET_EWS.format(geo)\n",
    "\n",
    "\n",
    "                for seed,site_name in enumerate(site_names):\n",
    "                    try:\n",
    "                        if verbose:\n",
    "                            print('predicting for =>', site_name)\n",
    "                            print('seed =>', seed)\n",
    "\n",
    "                        #---- Load data\n",
    "                        timeseries_all_data = pd.read_csv(path_dataset + '/{0}_preprocessed.csv'.format(site_name), index_col=0)\n",
    "                        df_events = pd.read_csv(path_dataset + '/{0}_event_dates.csv'.format(site_name), index_col=0)\n",
    "                        timeseries_lambda_allproxies = pd.read_csv(path_dataset + '/{0}_lambda.csv'.format(site_name), index_col=0)\n",
    "                        timeseries_events_binary = pd.read_csv(path_dataset + '/{0}_binary.csv'.format(site_name), index_col=0)\n",
    "                        dates = list(timeseries_events_binary.index.values)\n",
    "\n",
    "                        # Candidate predictors: every binary column except the target and 'dummy'.\n",
    "                        name_predictors_for_fitting = list(timeseries_events_binary)\n",
    "                        if name_target in name_predictors_for_fitting:\n",
    "                            name_predictors_for_fitting.remove(name_target)\n",
    "\n",
    "                        if 'dummy' in name_predictors_for_fitting:\n",
    "                            name_predictors_for_fitting.remove('dummy')\n",
    "\n",
    "                        name_predictors_for_fitting = filter_dataset(dataset_combination, name_predictors_for_fitting)\n",
    "\n",
    "                        # Whitelisting gt: drop gt_/gt2_ predictors that are not whitelisted.\n",
    "                        name_predictors_for_fitting = [\n",
    "                            t for t in name_predictors_for_fitting\n",
    "                            if not (('gt2_' in t or 'gt_' in t) and t not in whitelisted_terms)\n",
    "                        ]\n",
    "\n",
    "                        # Print the predictor set once per experiment.\n",
    "                        if printdataset:\n",
    "                            print(name_predictors_for_fitting)\n",
    "                            printdataset = False\n",
    "\n",
    "                        #name_predictors_for_fitting = [t for t in name_predictors_for_fitting if 'neighbor_' not in t]\n",
    "                        #name_predictors_for_fitting = [t for t in name_predictors_for_fitting if 'gt2_' not in t]\n",
    "                        #print(name_predictors_for_fitting)\n",
    "\n",
    "                        #----Counting Number of Events\n",
    "                        event_dates_gold_standard =  df_events[df_events['name_timeseries'] == name_target].sort_values('date_start')\n",
    "                        n_events = event_dates_gold_standard.shape[0]\n",
    "                        site_overall_performance = []\n",
    "                        index = []\n",
    "\n",
    "                        if verbose:\n",
    "                            print('event gold standard =>', event_dates_gold_standard)\n",
    "                            print('N events in total=>', n_events)\n",
    "                        #---- Training / Predict Loop\n",
    "\n",
    "                        # EXP 3 (disabled): only test on wave 3 and above, excluding first wave activation\n",
    "                        # Don't forget to change range(1, n_events) => range(2,n_events)!!!\n",
    "                        #\n",
    "                        # if n_events > 2:\n",
    "                        #     print('{0} with {1} events'.format(site_name, n_events))\n",
    "                        #     print('Removing first event from gold_standard')\n",
    "                        #     v = timeseries_events_binary[name_target].values.ravel()\n",
    "                        #     indices_events = [ind for ind, val in enumerate(v) if val == 1]\n",
    "                        #     v[indices_events[0]] = 0 # removing\n",
    "                        #     timeseries_events_binary[name_target] = v\n",
    "                        #     print('succesfully removed')\n",
    "                        # else:\n",
    "                        #     continue\n",
    "\n",
    "                        # Exp 2 (disabled): only predict second wave wave\n",
    "                        #\n",
    "                        # if n_events > 2:\n",
    "                        #     n_events = 2\n",
    "\n",
    "                        for n_events_training in range(1,n_events):\n",
    "\n",
    "                            date_start_training_dataset, date_end_training_dataset, date_start_test_dataset, date_end_test_dataset, date_event = get_training_test_dates(n_events_training,\n",
    "                                                event_dates_gold_standard,\n",
    "                                                timeseries_events_binary,\n",
    "                                                forward_window=3,\n",
    "                                               verbose=verbose)\n",
    "\n",
    "                            if verbose:\n",
    "                                print('N_events Training =>', n_events_training)\n",
    "                                print('\\n\\n')\n",
    "\n",
    "\n",
    "                            dataset_training = timeseries_events_binary.loc[date_start_training_dataset:date_end_training_dataset].copy()\n",
    "                            dataset_test = timeseries_events_binary.loc[date_start_test_dataset:date_end_test_dataset].copy()\n",
    "                            X_training = dataset_training[name_predictors_for_fitting]\n",
    "                            Y_training = dataset_training[[name_target]]\n",
    "                            X_predict = dataset_test[name_predictors_for_fitting]\n",
    "                            Y_predict = dataset_test[[name_target]]\n",
    "\n",
    "                            #---- Fitting Model\n",
    "                            # Splitting is only requested when more than one training event exists.\n",
    "                            if n_events_training > 1 and split_training:\n",
    "                                split = True\n",
    "                            else:\n",
    "                                split = False\n",
    "\n",
    "                            if model_type == 'ews':\n",
    "                                ews = EWS_Discrete(metric_to_optimize_beta=metric_to_optimize_beta) # SINGLE EWS\n",
    "                            elif model_type == 'ensemble':\n",
    "                                ews = EWS_Discrete_Ensemble(metric_to_optimize_beta=metric_to_optimize_beta, # EWS Ensemble\n",
    "                                                            seed=seed+1,\n",
    "                                                            max_n_ews= max_n_voters,\n",
    "                                                            split=split)\n",
    "\n",
    "                            ews.fit(X_training,Y_training)\n",
    "\n",
    "                            # Predict ------\n",
    "                            try:\n",
    "                                # TODO Move half voting threshold to model\n",
    "                                ews_timeseries = ews.predict(X_predict)\n",
    "                            except Exception as t:\n",
    "                                #print(t)\n",
    "                                #print('failed to predict n_event:{0} for site_name:{1}'.format(n_events_training, site_name))\n",
    "                                # Record the failure (appending, so earlier entries survive) and skip this event.\n",
    "                                error_tracker['site_name'].append(site_name)\n",
    "                                error_tracker['n_events_training'].append(n_events_training)\n",
    "                                error_tracker['error_type'].append('Predict Failure')\n",
    "                                error_tracker['error'].append(t)\n",
    "                                continue\n",
    "\n",
    "                            #---- Benchmarking data\n",
    "                            df_to_benchmark = pd.concat([ews_timeseries[['ews_thresholded']], Y_predict], axis=1)\n",
    "                            df_to_benchmark.columns = ['x','y']\n",
    "\n",
    "                            # The gold standard of the test window is reduced to the single event under test.\n",
    "                            df_to_benchmark['y'] = 0\n",
    "                            df_to_benchmark.loc[date_event, 'y'] = 1\n",
    "\n",
    "\n",
    "                            if verbose:\n",
    "                                    print('Test',dataset_test[name_target])\n",
    "                                    print('training', dataset_training.index.values)\n",
    "                                    print('X_predict index =>',ews_timeseries.index.values)\n",
    "                                    print('Y_predict_index =>', Y_predict.index.values)\n",
    "                                    print('df_to_benchmark index =>', df_to_benchmark.index.values)\n",
    "\n",
    "\n",
    "                            # Testing for different tolerances:\n",
    "                            labeled_events_test_by_tolerance = {}\n",
    "                            labeling_error = False\n",
    "                            for tol in late_tolerance:\n",
    "                                labeled_events_test = label_events(df_to_benchmark,\n",
    "                                               lookback_window=lookback_window_ews,\n",
    "                                                tp_late_tolerance=tol,\n",
    "                                             verbose=False)\n",
    "\n",
    "                                if verbose:\n",
    "                                    print('\\n\\n')\n",
    "                                    print('Checking for Tol => {0} '.format(tol))\n",
    "                                    print('Events =>',labeled_events_test)\n",
    "\n",
    "                                # Check for events which happen too early and mark them as not benchmarkable\n",
    "                                #(i.e. event happened at date indexed at location 2 but lookback window for predictors > 2)\n",
    "                                if len(labeled_events_test['tp']) == 0 and len(labeled_events_test['fn']) == 0:\n",
    "                                    print('Event Labeling Failure: event may be too early => {0}- {1}'.format(site_name, n_events_training))\n",
    "                                    error_tracker['site_name'].append(site_name)\n",
    "                                    error_tracker['n_events_training'].append(n_events_training)\n",
    "                                    error_tracker['error_type'].append('Event Labeling Failure: event may be too early')\n",
    "                                    error_tracker['error'].append(None)\n",
    "                                    labeling_error = True\n",
    "                                    # Skip only this event. A leftover debug raise used to abort the\n",
    "                                    # whole site here, discarding results of its other events.\n",
    "                                    break\n",
    "\n",
    "                                labeled_events_test = add_sfps(ews_timeseries[['ews_thresholded']],df_to_benchmark,labeled_events_test, verbose=False)\n",
    "                                labeled_events_test_by_tolerance[tol] = copy.copy(labeled_events_test)\n",
    "\n",
    "                                if verbose:\n",
    "                                    print('n_events training', n_events_training)\n",
    "                                    print('tol =>', tol)\n",
    "                                    print('labeled events =>')\n",
    "                                    print(labeled_events_test)\n",
    "\n",
    "                            if labeling_error: continue\n",
    "                            # gs_data, gs_lambda_data, gs_event_data, ews, ews_thresholded\n",
    "                            d1 = ews_timeseries.index.values[0]\n",
    "                            d2 = ews_timeseries.index.values[-1]\n",
    "\n",
    "\n",
    "                            timeseries_out_of_sample = pd.concat([timeseries_all_data[[name_target]].loc[d1:d2],\n",
    "                                       timeseries_lambda_allproxies[[name_target]].loc[d1:d2],\n",
    "                                       Y_predict,\n",
    "                                       ews_timeseries[['ews', 'ews_thresholded']]\n",
    "                                      ], axis=1)\n",
    "                            timeseries_out_of_sample.columns = ['target_original', 'target_lambda', 'target_event', 'ews', 'ews_thresholded']\n",
    "                            #timeseries_out_of_sample\n",
    "                            timeseries_out_of_sample['target_event'] = df_to_benchmark['y']\n",
    "\n",
    "                            #----- Saving Data\n",
    "                            # NOTE(review): the key 'test_end_date,' contains a stray comma; it is kept\n",
    "                            # unchanged because readers of the pickle may look it up under that name.\n",
    "                            experiment_meta = {\n",
    "                                 'site_name': site_name,\n",
    "                                'lambda_type':type_lambda,\n",
    "                                'n_training_events': n_events_training,\n",
    "                                 'training_start_date': timeseries_events_binary.index.values[0],\n",
    "                                 'training_end_date': date_end_training_dataset,\n",
    "                                 'test_start_date': date_start_test_dataset,\n",
    "                                 'test_end_date,': date_end_test_dataset,\n",
    "                                 'model':ews,\n",
    "                                 'performance': labeled_events_test_by_tolerance,\n",
    "                                'name_target':name_target,\n",
    "                                'timeseries':timeseries_out_of_sample\n",
    "                                 }\n",
    "\n",
    "                            # The file holds a pickle even though its name ends in .csv; the name is kept as is.\n",
    "                            path_pickle = path_output_results+  '/{2}/pickle/{0}_{1}.csv'.format(site_name, n_events_training, experiment_name)\n",
    "                            with open(path_pickle, 'wb') as pickle_file:\n",
    "                                pickle.dump(experiment_meta, pickle_file)\n",
    "\n",
    "\n",
    "                            if verbose:\n",
    "                                print('labeled Events by tol =>', labeled_events_test_by_tolerance)\n",
    "                            # One metrics row per tolerance; index is appended together with each row\n",
    "                            # so the two lists always have the same length.\n",
    "                            if isinstance(ews, EWS_Discrete):\n",
    "                                # EWS Model\n",
    "                                for tol in late_tolerance:\n",
    "                                    site_overall_performance.append(events2metrics(labeled_events_test_by_tolerance[tol]) + [ews.beta] + [tol])\n",
    "                                    index.append(n_events_training)\n",
    "                            elif isinstance(ews, EWS_Discrete_Ensemble):\n",
    "                                # EWS Ensemble Model\n",
    "                                for tol in late_tolerance:\n",
    "                                    site_overall_performance.append(events2metrics(labeled_events_test_by_tolerance[tol]) + [None] + [tol])\n",
    "                                    index.append(n_events_training)\n",
    "                        site_overall_performance = pd.DataFrame(site_overall_performance, columns = ['tp', 'fp', 'tn', 'fn','sfp', 'tpr', 'precision', 'm1','m2','m3','earliness', r'$\\beta$', 'late_activation_tolerance'], index =index)\n",
    "                        site_overall_performance.index.name =('N training events')\n",
    "                        site_overall_performance['site_name'] = site_name\n",
    "                        site_overall_performance.to_csv(path_output_results + '/{1}/performance_per_fit/{0}.csv'.format(site_name, experiment_name))\n",
    "                    except Exception as t:\n",
    "                        # Best effort per site: report the problem and move on to the next site.\n",
    "                        print(t)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "reproducibility_('UpToDate',)_state\n",
      "reproducibility_('Apple Mobility',)_state\n",
      "reproducibility_('Apple Mobility', 'UpToDate')_state\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Total Events</th>\n",
       "      <th>Early Warning</th>\n",
       "      <th>Sync Warning</th>\n",
       "      <th>Late Warning</th>\n",
       "      <th>Event and EWS close to threshold</th>\n",
       "      <th>Missed Outbreaks</th>\n",
       "      <th>Warning Associated to Activity but no Event</th>\n",
       "      <th>False Alarm</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>reproducibility_('Apple Mobility',)_state</th>\n",
       "      <td>172</td>\n",
       "      <td>75</td>\n",
       "      <td>14</td>\n",
       "      <td>5</td>\n",
       "      <td>21</td>\n",
       "      <td>57</td>\n",
       "      <td>16</td>\n",
       "      <td>56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>reproducibility_('Apple Mobility', 'UpToDate')_state</th>\n",
       "      <td>179</td>\n",
       "      <td>73</td>\n",
       "      <td>14</td>\n",
       "      <td>7</td>\n",
       "      <td>32</td>\n",
       "      <td>53</td>\n",
       "      <td>18</td>\n",
       "      <td>58</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>reproducibility_('UpToDate',)_state</th>\n",
       "      <td>78</td>\n",
       "      <td>17</td>\n",
       "      <td>6</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>49</td>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   Total Events Early Warning  \\\n",
       "reproducibility_('Apple Mobility',)_state                   172            75   \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...          179            73   \n",
       "reproducibility_('UpToDate',)_state                          78            17   \n",
       "\n",
       "                                                   Sync Warning Late Warning  \\\n",
       "reproducibility_('Apple Mobility',)_state                    14            5   \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...           14            7   \n",
       "reproducibility_('UpToDate',)_state                           6            6   \n",
       "\n",
       "                                                   Event and EWS close to threshold  \\\n",
       "reproducibility_('Apple Mobility',)_state                                        21   \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...                               32   \n",
       "reproducibility_('UpToDate',)_state                                               0   \n",
       "\n",
       "                                                   Missed Outbreaks  \\\n",
       "reproducibility_('Apple Mobility',)_state                        57   \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...               53   \n",
       "reproducibility_('UpToDate',)_state                              49   \n",
       "\n",
       "                                                   Warning Associated to Activity but no Event  \\\n",
       "reproducibility_('Apple Mobility',)_state                                                   16   \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...                                          18   \n",
       "reproducibility_('UpToDate',)_state                                                          5   \n",
       "\n",
       "                                                   False Alarm  \n",
       "reproducibility_('Apple Mobility',)_state                   56  \n",
       "reproducibility_('Apple Mobility', 'UpToDate')_...          58  \n",
       "reproducibility_('UpToDate',)_state                          3  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary table: one row per experiment, one column per event category.\n",
    "geo = 'state'\n",
    "\n",
    "if geo == 'state':\n",
    "    site_names = constants.site_names[:50]\n",
    "elif geo == 'county':\n",
    "    site_names = constants.site_names[50:]\n",
    "\n",
    "# Event categories kept from each analysis table, in display order.\n",
    "rows_ordered = [\n",
    "    'Total Events',\n",
    "    'Early Warning',\n",
    "    'Sync Warning',\n",
    "    'Late Warning',\n",
    "    'Event and EWS close to threshold',\n",
    "    'Missed Outbreaks',\n",
    "    'Warning Associated to Activity but no Event',\n",
    "    'False Alarm',\n",
    "]\n",
    "\n",
    "# Rebuild the experiment names exactly as the fitting cell generated them.\n",
    "experiment_names = {}\n",
    "for dataset_combination, model_type, name_signal, source_type in product(\n",
    "        dataset_combinations[:3], ['ews'], ['rt_parag_probability'], ['cases']):\n",
    "    name_target = '{0}_{1}'.format(name_signal, source_type) # 'JHU_{0}'.format(source_type) #_corrected\n",
    "    experiment_names[dataset_combination] = 'reproducibility_{4}_{3}'.format(current_date, name_target, model_type, geo, dataset_combination)\n",
    "\n",
    "tables_per_experiment = []\n",
    "names_display = []\n",
    "\n",
    "for display_name, experiment_name in experiment_names.items():\n",
    "    print(experiment_name)\n",
    "    # The same threshold applies whether or not the name mentions 'probability'.\n",
    "    threshold = .95\n",
    "\n",
    "    if 'lucas' in experiment_name:\n",
    "        df,_ = perform_analysis_singlesource(experiment_name, threshold=threshold, filter_date=None, site_names = site_names)\n",
    "    else:\n",
    "        df,_ = perform_analysis(experiment_name, threshold=threshold, filter_date=None, date=None, site_names = site_names)\n",
    "\n",
    "    # Keep the ordered categories and only this experiment's column.\n",
    "    tables_per_experiment.append(\n",
    "        df.T[rows_ordered].T[[experiment_name]].copy()\n",
    "    )\n",
    "    names_display.append(display_name)\n",
    "\n",
    "dfs = pd.concat(tables_per_experiment, axis=1)\n",
    "dfs = dfs.T.sort_values('Early Warning', ascending=False)\n",
    "dfs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
